🚀 Feature: Add and apply more stricter golangci-lint linting rules (#2286)

* golangci-lint: add and apply more stricter linting rules

* github: drop security workflow now that we use gosec linter inside golangci-lint

* github: use official golangci-lint CI linter

* Add editorconfig and gitattributes file
pull/2313/head
leonklingele 2023-01-27 09:01:37 +01:00 committed by GitHub
parent 7327a17951
commit 167a8b5e94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
120 changed files with 1889 additions and 1321 deletions

8
.editorconfig Normal file
View File

@ -0,0 +1,8 @@
; This file is for unifying the coding style for different editors and IDEs.
; More information at http://editorconfig.org
; This style originates from https://github.com/fewagency/best-practices
root = true
[*]
charset = utf-8
end_of_line = lf

12
.gitattributes vendored Normal file
View File

@ -0,0 +1,12 @@
# Handle line endings automatically for files detected as text
# and leave all files detected as binary untouched.
* text=auto eol=lf
# Force batch scripts to always use CRLF line endings so that if a repo is accessed
# in Windows via a file share from Linux, the scripts will work.
*.{cmd,[cC][mM][dD]} text eol=crlf
*.{bat,[bB][aA][tT]} text eol=crlf
# Force bash scripts to always use LF line endings so that if a repo is accessed
# in Unix via a file share from Windows, the scripts will work.
*.sh text eol=lf

View File

@ -1,17 +1,28 @@
# Adapted from https://github.com/golangci/golangci-lint-action/blob/b56f6f529003f1c81d4d759be6bd5f10bf9a0fa0/README.md#how-to-use
name: golangci-lint
on: on:
push: push:
tags:
- v*
branches: branches:
- master - master
- main - main
pull_request: pull_request:
name: Linter permissions:
contents: read
jobs: jobs:
Golint: golangci:
name: lint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Fetch Repository - uses: actions/setup-go@v3
uses: actions/checkout@v3
- name: Run Golint
uses: reviewdog/action-golangci-lint@v2
with: with:
golangci_lint_flags: "--tests=false" # NOTE: Keep this in sync with the version from go.mod
go-version: 1.19
- uses: actions/checkout@v3
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
# NOTE: Keep this in sync with the version from .golangci.yml
version: v1.50.1

View File

@ -1,17 +0,0 @@
on:
push:
branches:
- master
- main
pull_request:
name: Security
jobs:
Gosec:
runs-on: ubuntu-latest
steps:
- name: Fetch Repository
uses: actions/checkout@v3
- name: Run Gosec
uses: securego/gosec@master
with:
args: -exclude-dir=internal/*/ ./...

258
.golangci.yml Normal file
View File

@ -0,0 +1,258 @@
# Created based on v1.50.1
# NOTE: Keep this in sync with the version in .github/workflows/linter.yml
run:
modules-download-mode: readonly
skip-dirs-use-default: false
skip-dirs:
- internal # TODO: Also apply proper linting for internal dir
output:
sort-results: true
linters-settings:
# TODO: Eventually enable these checks
# depguard:
# include-go-root: true
# packages:
# - flag
# - io/ioutil
# - reflect
# - unsafe
# packages-with-error-message:
# - flag: '`flag` package is only allowed in main.go'
# - io/ioutil: '`io/ioutil` package is deprecated, use the `io` and `os` package instead'
# - reflect: '`reflect` package is dangerous to use'
# - unsafe: '`unsafe` package is dangerous to use'
errcheck:
check-type-assertions: true
check-blank: true
disable-default-exclusions: true
errchkjson:
report-no-exported: true
exhaustive:
default-signifies-exhaustive: true
forbidigo:
forbid:
- ^(fmt\.Print(|f|ln)|print|println)$
- 'http\.Default(Client|Transport)'
# TODO: Eventually enable these patterns
# - 'time\.Sleep'
# - 'panic'
gci:
sections:
- standard
- prefix(github.com/gofiber/fiber)
- default
- blank
- dot
custom-order: true
gocritic:
disabled-checks:
- ifElseChain
gofumpt:
module-path: github.com/gofiber/fiber
extra-rules: true
gosec:
config:
global:
audit: true
govet:
check-shadowing: true
enable-all: true
disable:
- shadow
- fieldalignment
grouper:
import-require-single-import: true
import-require-grouping: true
misspell:
locale: US
nolintlint:
require-explanation: true
require-specific: true
nonamedreturns:
report-error-in-defer: true
predeclared:
q: true
promlinter:
strict: true
revive:
enable-all-rules: true
rules:
# Provided by gomnd linter
- name: add-constant
disabled: true
- name: argument-limit
disabled: true
# Provided by bidichk
- name: banned-characters
disabled: true
- name: cognitive-complexity
disabled: true
- name: cyclomatic
disabled: true
- name: exported
disabled: true
- name: file-header
disabled: true
- name: function-result-limit
disabled: true
- name: function-length
disabled: true
- name: line-length-limit
disabled: true
- name: max-public-structs
disabled: true
- name: modifies-parameter
disabled: true
- name: nested-structs
disabled: true
- name: package-comments
disabled: true
stylecheck:
checks:
- all
- -ST1000
- -ST1020
- -ST1021
- -ST1022
tagliatelle:
case:
rules:
json: snake
#tenv:
# all: true
#unparam:
# check-exported: true
wrapcheck:
ignorePackageGlobs:
- github.com/gofiber/fiber/*
- github.com/valyala/fasthttp
issues:
exclude-use-default: false
linters:
enable:
- asasalint
- asciicheck
- bidichk
- bodyclose
- containedctx
- contextcheck
# - cyclop
# - deadcode
# - decorder
- depguard
- dogsled
# - dupl
# - dupword
- durationcheck
- errcheck
- errchkjson
- errname
- errorlint
- execinquery
- exhaustive
# - exhaustivestruct
# - exhaustruct
- exportloopref
- forbidigo
- forcetypeassert
# - funlen
- gci
- gochecknoglobals
- gochecknoinits
# - gocognit
- goconst
- gocritic
# - gocyclo
# - godot
# - godox
# - goerr113
- gofmt
- gofumpt
# - goheader
- goimports
# - golint
- gomnd
- gomoddirectives
# - gomodguard
- goprintffuncname
- gosec
- gosimple
- govet
- grouper
# - ifshort
# - importas
- ineffassign
# - interfacebloat
# - interfacer
# - ireturn
# - lll
- loggercheck
# - maintidx
# - makezero
# - maligned
- misspell
- nakedret
# - nestif
- nilerr
- nilnil
# - nlreturn
- noctx
- nolintlint
- nonamedreturns
# - nosnakecase
- nosprintfhostport
# - paralleltest # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
# - prealloc
- predeclared
- promlinter
- reassign
- revive
- rowserrcheck
# - scopelint
- sqlclosecheck
- staticcheck
# - structcheck
- stylecheck
- tagliatelle
# - tenv # TODO: Enable once we drop support for Go 1.16
# - testableexamples
# - testpackage # TODO: Enable once https://github.com/gofiber/fiber/issues/2252 is implemented
- thelper
# - tparallel # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
- typecheck
- unconvert
- unparam
- unused
- usestdlibvars
# - varcheck
# - varnamelen
- wastedassign
- whitespace
- wrapcheck
# - wsl

58
app.go
View File

@ -14,6 +14,7 @@ import (
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
"log"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@ -24,6 +25,7 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -306,7 +308,7 @@ type Config struct {
// FEATURE: v2.3.x // FEATURE: v2.3.x
// The router executes the same handler by default if StrictRouting or CaseSensitive is disabled. // The router executes the same handler by default if StrictRouting or CaseSensitive is disabled.
// Enabling RedirectFixedPath will change this behaviour into a client redirect to the original route path. // Enabling RedirectFixedPath will change this behavior into a client redirect to the original route path.
// Using the status code 301 for GET requests and 308 for all other request methods. // Using the status code 301 for GET requests and 308 for all other request methods.
// //
// Default: false // Default: false
@ -454,6 +456,8 @@ const (
) )
// HTTP methods enabled by default // HTTP methods enabled by default
//
//nolint:gochecknoglobals // Using a global var is fine here
var DefaultMethods = []string{ var DefaultMethods = []string{
MethodGet, MethodGet,
MethodHead, MethodHead,
@ -467,7 +471,7 @@ var DefaultMethods = []string{
} }
// DefaultErrorHandler that process return errors from handlers // DefaultErrorHandler that process return errors from handlers
var DefaultErrorHandler = func(c *Ctx, err error) error { func DefaultErrorHandler(c *Ctx, err error) error {
code := StatusInternalServerError code := StatusInternalServerError
var e *Error var e *Error
if errors.As(err, &e) { if errors.As(err, &e) {
@ -519,7 +523,7 @@ func New(config ...Config) *App {
if app.config.ETag { if app.config.ETag {
if !IsChild() { if !IsChild() {
fmt.Println("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.") log.Printf("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.\n")
} }
} }
@ -587,7 +591,7 @@ func (app *App) handleTrustedProxy(ipAddress string) {
if strings.Contains(ipAddress, "/") { if strings.Contains(ipAddress, "/") {
_, ipNet, err := net.ParseCIDR(ipAddress) _, ipNet, err := net.ParseCIDR(ipAddress)
if err != nil { if err != nil {
fmt.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err) log.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err)
} else { } else {
app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet) app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet)
} }
@ -822,7 +826,7 @@ func (app *App) Config() Config {
} }
// Handler returns the server handler. // Handler returns the server handler.
func (app *App) Handler() fasthttp.RequestHandler { func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
// prepare the server for the start // prepare the server for the start
app.startupProcess() app.startupProcess()
return app.handler return app.handler
@ -887,7 +891,7 @@ func (app *App) Hooks() *Hooks {
// Test is used for internal debugging by passing a *http.Request. // Test is used for internal debugging by passing a *http.Request.
// Timeout is optional and defaults to 1s, -1 will disable it completely. // Timeout is optional and defaults to 1s, -1 will disable it completely.
func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response, err error) { func (app *App) Test(req *http.Request, msTimeout ...int) (*http.Response, error) {
// Set timeout // Set timeout
timeout := 1000 timeout := 1000
if len(msTimeout) > 0 { if len(msTimeout) > 0 {
@ -902,15 +906,15 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
// Dump raw http request // Dump raw http request
dump, err := httputil.DumpRequest(req, true) dump, err := httputil.DumpRequest(req, true)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to dump request: %w", err)
} }
// Create test connection // Create test connection
conn := new(testConn) conn := new(testConn)
// Write raw http request // Write raw http request
if _, err = conn.r.Write(dump); err != nil { if _, err := conn.r.Write(dump); err != nil {
return nil, err return nil, fmt.Errorf("failed to write: %w", err)
} }
// prepare the server for the start // prepare the server for the start
app.startupProcess() app.startupProcess()
@ -943,7 +947,7 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
} }
// Check for errors // Check for errors
if err != nil && err != fasthttp.ErrGetOnly { if err != nil && !errors.Is(err, fasthttp.ErrGetOnly) {
return nil, err return nil, err
} }
@ -951,12 +955,17 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
buffer := bufio.NewReader(&conn.w) buffer := bufio.NewReader(&conn.w)
// Convert raw http response to *http.Response // Convert raw http response to *http.Response
return http.ReadResponse(buffer, req) res, err := http.ReadResponse(buffer, req)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
return res, nil
} }
type disableLogger struct{} type disableLogger struct{}
func (dl *disableLogger) Printf(_ string, _ ...interface{}) { func (*disableLogger) Printf(_ string, _ ...interface{}) {
// fmt.Println(fmt.Sprintf(format, args...)) // fmt.Println(fmt.Sprintf(format, args...))
} }
@ -967,7 +976,7 @@ func (app *App) init() *App {
// Only load templates if a view engine is specified // Only load templates if a view engine is specified
if app.config.Views != nil { if app.config.Views != nil {
if err := app.config.Views.Load(); err != nil { if err := app.config.Views.Load(); err != nil {
fmt.Printf("views: %v\n", err) log.Printf("[Warning]: failed to load views: %v\n", err)
} }
} }
@ -1039,25 +1048,30 @@ func (app *App) ErrorHandler(ctx *Ctx, err error) error {
// errors before calling the application's error handler method. // errors before calling the application's error handler method.
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) { func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx) c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok { defer app.ReleaseCtx(c)
var errNetOP *net.OpError
switch {
case errors.As(err, new(*fasthttp.ErrSmallBuffer)):
err = ErrRequestHeaderFieldsTooLarge err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { case errors.As(err, &errNetOP) && errNetOP.Timeout():
err = ErrRequestTimeout err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge { case errors.Is(err, fasthttp.ErrBodyTooLarge):
err = ErrRequestEntityTooLarge err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly { case errors.Is(err, fasthttp.ErrGetOnly):
err = ErrMethodNotAllowed err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") { case strings.Contains(err.Error(), "timeout"):
err = ErrRequestTimeout err = ErrRequestTimeout
} else { default:
err = NewError(StatusBadRequest, err.Error()) err = NewError(StatusBadRequest, err.Error())
} }
if catch := app.ErrorHandler(c, err); catch != nil { if catch := app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError) log.Printf("serverErrorHandler: failed to call ErrorHandler: %v\n", catch)
_ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here
return
} }
app.ReleaseCtx(c)
} }
// startupProcess Is the method which executes all the necessary processes just before the start of the server. // startupProcess Is the method which executes all the necessary processes just before the start of the server.

View File

@ -2,6 +2,7 @@
// 🤖 Github Repository: https://github.com/gofiber/fiber // 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io // 📌 API Documentation: https://docs.gofiber.io
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package fiber package fiber
import ( import (
@ -23,15 +24,16 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil" "github.com/valyala/fasthttp/fasthttputil"
) )
var testEmptyHandler = func(c *Ctx) error { func testEmptyHandler(_ *Ctx) error {
return nil return nil
} }
func testStatus200(t *testing.T, app *App, url string, method string) { func testStatus200(t *testing.T, app *App, url, method string) {
t.Helper() t.Helper()
req := httptest.NewRequest(method, url, nil) req := httptest.NewRequest(method, url, nil)
@ -42,6 +44,8 @@ func testStatus200(t *testing.T, app *App, url string, method string) {
} }
func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBodyError string) { func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBodyError string) {
t.Helper()
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 500, resp.StatusCode, "Status code") utils.AssertEqual(t, 500, resp.StatusCode, "Status code")
@ -140,7 +144,6 @@ func Test_App_ServerErrorHandler_SmallReadBuffer(t *testing.T) {
logHeaderSlice := make([]string, 5000) logHeaderSlice := make([]string, 5000)
request.Header.Set("Very-Long-Header", strings.Join(logHeaderSlice, "-")) request.Header.Set("Very-Long-Header", strings.Join(logHeaderSlice, "-"))
_, err := app.Test(request) _, err := app.Test(request)
if err == nil { if err == nil {
t.Error("Expect an error at app.Test(request)") t.Error("Expect an error at app.Test(request)")
} }
@ -470,7 +473,6 @@ func Test_App_Use_MultiplePrefix(t *testing.T) {
body, err = io.ReadAll(resp.Body) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "/test/doe", string(body)) utils.AssertEqual(t, "/test/doe", string(body))
} }
func Test_App_Use_StrictRouting(t *testing.T) { func Test_App_Use_StrictRouting(t *testing.T) {
@ -515,7 +517,7 @@ func Test_App_Add_Method_Test(t *testing.T) {
} }
}() }()
methods := append(DefaultMethods, "JOHN") methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here
app := New(Config{ app := New(Config{
RequestMethods: methods, RequestMethods: methods,
}) })
@ -780,7 +782,7 @@ func Test_App_ShutdownWithTimeout(t *testing.T) {
case <-timer.C: case <-timer.C:
t.Fatal("idle connections not closed on shutdown") t.Fatal("idle connections not closed on shutdown")
case err := <-shutdownErr: case err := <-shutdownErr:
if err == nil || err != context.DeadlineExceeded { if err == nil || !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded) t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
} }
} }
@ -852,7 +854,7 @@ func Test_App_Static_MaxAge(t *testing.T) {
app.Static("/", "./.github", Static{MaxAge: 100}) app.Static("/", "./.github", Static{MaxAge: 100})
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil)) resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code") utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
@ -866,19 +868,19 @@ func Test_App_Static_Custom_CacheControl(t *testing.T) {
app := New() app := New()
app.Static("/", "./.github", Static{ModifyResponse: func(c *Ctx) error { app.Static("/", "./.github", Static{ModifyResponse: func(c *Ctx) error {
if strings.Contains(string(c.GetRespHeader("Content-Type")), "text/html") { if strings.Contains(c.GetRespHeader("Content-Type"), "text/html") {
c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate") c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
} }
return nil return nil
}}) }})
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil)) resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, "no-cache, no-store, must-revalidate", resp.Header.Get(HeaderCacheControl), "CacheControl Control") utils.AssertEqual(t, "no-cache, no-store, must-revalidate", resp.Header.Get(HeaderCacheControl), "CacheControl Control")
normal_resp, normal_err := app.Test(httptest.NewRequest("GET", "/config.yml", nil)) respNormal, errNormal := app.Test(httptest.NewRequest(MethodGet, "/config.yml", nil))
utils.AssertEqual(t, nil, normal_err, "app.Test(req)") utils.AssertEqual(t, nil, errNormal, "app.Test(req)")
utils.AssertEqual(t, "", normal_resp.Header.Get(HeaderCacheControl), "CacheControl Control") utils.AssertEqual(t, "", respNormal.Header.Get(HeaderCacheControl), "CacheControl Control")
} }
// go test -run Test_App_Static_Download // go test -run Test_App_Static_Download
@ -890,7 +892,7 @@ func Test_App_Static_Download(t *testing.T) {
app.Static("/fiber.png", "./.github/testdata/fs/img/fiber.png", Static{Download: true}) app.Static("/fiber.png", "./.github/testdata/fs/img/fiber.png", Static{Download: true})
resp, err := app.Test(httptest.NewRequest("GET", "/fiber.png", nil)) resp, err := app.Test(httptest.NewRequest(MethodGet, "/fiber.png", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code") utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
@ -1067,7 +1069,7 @@ func Test_App_Static_Next(t *testing.T) {
t.Run("app.Static is skipped: invoking Get handler", func(t *testing.T) { t.Run("app.Static is skipped: invoking Get handler", func(t *testing.T) {
t.Parallel() t.Parallel()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(MethodGet, "/", nil)
req.Header.Set("X-Custom-Header", "skip") req.Header.Set("X-Custom-Header", "skip")
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -1082,7 +1084,7 @@ func Test_App_Static_Next(t *testing.T) {
t.Run("app.Static is not skipped: serving index.html", func(t *testing.T) { t.Run("app.Static is not skipped: serving index.html", func(t *testing.T) {
t.Parallel() t.Parallel()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(MethodGet, "/", nil)
req.Header.Set("X-Custom-Header", "don't skip") req.Header.Set("X-Custom-Header", "don't skip")
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -1379,7 +1381,7 @@ func Test_Test_DumpError(t *testing.T) {
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", errorReader(0))) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", errorReader(0)))
utils.AssertEqual(t, true, resp == nil) utils.AssertEqual(t, true, resp == nil)
utils.AssertEqual(t, "errorReader", err.Error()) utils.AssertEqual(t, "failed to dump request: errorReader", err.Error())
} }
// go test -run Test_App_Handler // go test -run Test_App_Handler
@ -1393,7 +1395,7 @@ type invalidView struct{}
func (invalidView) Load() error { return errors.New("invalid view") } func (invalidView) Load() error { return errors.New("invalid view") }
func (i invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") } func (invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") }
// go test -run Test_App_Init_Error_View // go test -run Test_App_Init_Error_View
func Test_App_Init_Error_View(t *testing.T) { func Test_App_Init_Error_View(t *testing.T) {
@ -1405,7 +1407,9 @@ func Test_App_Init_Error_View(t *testing.T) {
utils.AssertEqual(t, "implement me", fmt.Sprintf("%v", err)) utils.AssertEqual(t, "implement me", fmt.Sprintf("%v", err))
} }
}() }()
_ = app.config.Views.Render(nil, "", nil)
err := app.config.Views.Render(nil, "", nil)
utils.AssertEqual(t, nil, err)
} }
// go test -run Test_App_Stack // go test -run Test_App_Stack
@ -1535,11 +1539,12 @@ func Test_App_SmallReadBuffer(t *testing.T) {
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
resp, err := http.Get("http://127.0.0.1:4006/small-read-buffer") req, err := http.NewRequestWithContext(context.Background(), MethodGet, "http://127.0.0.1:4006/small-read-buffer", http.NoBody)
if resp != nil {
utils.AssertEqual(t, 431, resp.StatusCode)
}
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
var client http.Client
resp, err := client.Do(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 431, resp.StatusCode)
utils.AssertEqual(t, nil, app.Shutdown()) utils.AssertEqual(t, nil, app.Shutdown())
}() }()
@ -1572,13 +1577,13 @@ func Test_App_New_Test_Parallel(t *testing.T) {
t.Run("Test_App_New_Test_Parallel_1", func(t *testing.T) { t.Run("Test_App_New_Test_Parallel_1", func(t *testing.T) {
t.Parallel() t.Parallel()
app := New(Config{Immutable: true}) app := New(Config{Immutable: true})
_, err := app.Test(httptest.NewRequest("GET", "/", nil)) _, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
}) })
t.Run("Test_App_New_Test_Parallel_2", func(t *testing.T) { t.Run("Test_App_New_Test_Parallel_2", func(t *testing.T) {
t.Parallel() t.Parallel()
app := New(Config{Immutable: true}) app := New(Config{Immutable: true})
_, err := app.Test(httptest.NewRequest("GET", "/", nil)) _, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
}) })
} }
@ -1591,7 +1596,7 @@ func Test_App_ReadBodyStream(t *testing.T) {
return c.SendString(fmt.Sprintf("%v %s", c.Request().IsBodyStream(), c.Body())) return c.SendString(fmt.Sprintf("%v %s", c.Request().IsBodyStream(), c.Body()))
}) })
testString := "this is a test" testString := "this is a test"
resp, err := app.Test(httptest.NewRequest("POST", "/", bytes.NewBufferString(testString))) resp, err := app.Test(httptest.NewRequest(MethodPost, "/", bytes.NewBufferString(testString)))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "io.ReadAll(resp.Body)") utils.AssertEqual(t, nil, err, "io.ReadAll(resp.Body)")
@ -1615,12 +1620,12 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
} }
file, err := mpf.File["test"][0].Open() file, err := mpf.File["test"][0].Open()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to open: %w", err)
} }
buffer := make([]byte, len(testString)) buffer := make([]byte, len(testString))
n, err := file.Read(buffer) n, err := file.Read(buffer)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to read: %w", err)
} }
if n != len(testString) { if n != len(testString) {
return fmt.Errorf("bad read length") return fmt.Errorf("bad read length")
@ -1636,7 +1641,7 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
utils.AssertEqual(t, len(testString), n, "writer n") utils.AssertEqual(t, len(testString), n, "writer n")
utils.AssertEqual(t, nil, w.Close(), "w.Close()") utils.AssertEqual(t, nil, w.Close(), "w.Close()")
req := httptest.NewRequest("POST", "/", b) req := httptest.NewRequest(MethodPost, "/", b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -1659,7 +1664,7 @@ func Test_App_Test_no_timeout_infinitely(t *testing.T) {
return nil return nil
}) })
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) req := httptest.NewRequest(MethodGet, "/", http.NoBody)
_, err = app.Test(req, -1) _, err = app.Test(req, -1)
}() }()
@ -1696,7 +1701,7 @@ func Test_App_SetTLSHandler(t *testing.T) {
func Test_App_AddCustomRequestMethod(t *testing.T) { func Test_App_AddCustomRequestMethod(t *testing.T) {
t.Parallel() t.Parallel()
methods := append(DefaultMethods, "TEST") methods := append(DefaultMethods, "TEST") //nolint:gocritic // We want a new slice here
app := New(Config{ app := New(Config{
RequestMethods: methods, RequestMethods: methods,
}) })

View File

@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -50,6 +51,7 @@ type Args = fasthttp.Args
// Copy from fasthttp // Copy from fasthttp
type RetryIfFunc = fasthttp.RetryIfFunc type RetryIfFunc = fasthttp.RetryIfFunc
//nolint:gochecknoglobals // TODO: Do not use a global var here
var defaultClient Client var defaultClient Client
// Client implements http client. // Client implements http client.
@ -186,11 +188,11 @@ func (a *Agent) Parse() error {
uri := a.req.URI() uri := a.req.URI()
isTLS := false var isTLS bool
scheme := uri.Scheme() scheme := uri.Scheme()
if bytes.Equal(scheme, strHTTPS) { if bytes.Equal(scheme, []byte(schemeHTTPS)) {
isTLS = true isTLS = true
} else if !bytes.Equal(scheme, strHTTP) { } else if !bytes.Equal(scheme, []byte(schemeHTTP)) {
return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
} }
@ -241,7 +243,7 @@ func (a *Agent) SetBytesV(k string, v []byte) *Agent {
// SetBytesKV sets the given 'key: value' header. // SetBytesKV sets the given 'key: value' header.
// //
// Use AddBytesKV for setting multiple header values under the same key. // Use AddBytesKV for setting multiple header values under the same key.
func (a *Agent) SetBytesKV(k []byte, v []byte) *Agent { func (a *Agent) SetBytesKV(k, v []byte) *Agent {
a.req.Header.SetBytesKV(k, v) a.req.Header.SetBytesKV(k, v)
return a return a
@ -281,7 +283,7 @@ func (a *Agent) AddBytesV(k string, v []byte) *Agent {
// //
// Multiple headers with the same key may be added with this function. // Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key. // Use SetBytesKV for setting a single header for the given key.
func (a *Agent) AddBytesKV(k []byte, v []byte) *Agent { func (a *Agent) AddBytesKV(k, v []byte) *Agent {
a.req.Header.AddBytesKV(k, v) a.req.Header.AddBytesKV(k, v)
return a return a
@ -652,10 +654,8 @@ func (a *Agent) Reuse() *Agent {
// certificate chain and host name. // certificate chain and host name.
func (a *Agent) InsecureSkipVerify() *Agent { func (a *Agent) InsecureSkipVerify() *Agent {
if a.HostClient.TLSConfig == nil { if a.HostClient.TLSConfig == nil {
/* #nosec G402 */ a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We explicitly let the user set insecure mode here
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402
} else { } else {
/* #nosec G402 */
a.HostClient.TLSConfig.InsecureSkipVerify = true a.HostClient.TLSConfig.InsecureSkipVerify = true
} }
@ -728,14 +728,14 @@ func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent {
// Bytes returns the status code, bytes body and errors of url. // Bytes returns the status code, bytes body and errors of url.
// //
// it's not safe to use Agent after calling [Agent.Bytes] // it's not safe to use Agent after calling [Agent.Bytes]
func (a *Agent) Bytes() (code int, body []byte, errs []error) { func (a *Agent) Bytes() (int, []byte, []error) {
defer a.release() defer a.release()
return a.bytes() return a.bytes()
} }
func (a *Agent) bytes() (code int, body []byte, errs []error) { func (a *Agent) bytes() (code int, body []byte, errs []error) { //nolint:nonamedreturns,revive // We want to overwrite the body in a deferred func. TODO: Check if we really need to do this. We eventually want to get rid of all named returns.
if errs = append(errs, a.errs...); len(errs) > 0 { if errs = append(errs, a.errs...); len(errs) > 0 {
return return code, body, errs
} }
var ( var (
@ -760,7 +760,7 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
code = resp.StatusCode() code = resp.StatusCode()
} }
body = append(a.dest, resp.Body()...) body = append(a.dest, resp.Body()...) //nolint:gocritic // We want to append to the returned slice here
if nilResp { if nilResp {
ReleaseResponse(resp) ReleaseResponse(resp)
@ -770,25 +770,25 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
if a.timeout > 0 { if a.timeout > 0 {
if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil { if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil {
errs = append(errs, err) errs = append(errs, err)
return return code, body, errs
} }
} else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) { } else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) {
if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil { if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil {
errs = append(errs, err) errs = append(errs, err)
return return code, body, errs
} }
} else if err := a.HostClient.Do(req, resp); err != nil { } else if err := a.HostClient.Do(req, resp); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
return return code, body, errs
} }
func printDebugInfo(req *Request, resp *Response, w io.Writer) { func printDebugInfo(req *Request, resp *Response, w io.Writer) {
msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr()) msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr())
_, _ = w.Write(utils.UnsafeBytes(msg)) _, _ = w.Write(utils.UnsafeBytes(msg)) //nolint:errcheck // This will never fail
_, _ = req.WriteTo(w) _, _ = req.WriteTo(w) //nolint:errcheck // This will never fail
_, _ = resp.WriteTo(w) _, _ = resp.WriteTo(w) //nolint:errcheck // This will never fail
} }
// String returns the status code, string body and errors of url. // String returns the status code, string body and errors of url.
@ -797,6 +797,7 @@ func printDebugInfo(req *Request, resp *Response, w io.Writer) {
func (a *Agent) String() (int, string, []error) { func (a *Agent) String() (int, string, []error) {
defer a.release() defer a.release()
code, body, errs := a.bytes() code, body, errs := a.bytes()
// TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it?
return code, utils.UnsafeString(body), errs return code, utils.UnsafeString(body), errs
} }
@ -805,12 +806,15 @@ func (a *Agent) String() (int, string, []error) {
// And bytes body will be unmarshalled to given v. // And bytes body will be unmarshalled to given v.
// //
// it's not safe to use Agent after calling [Agent.Struct] // it's not safe to use Agent after calling [Agent.Struct]
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) { func (a *Agent) Struct(v interface{}) (int, []byte, []error) {
defer a.release() defer a.release()
if code, body, errs = a.bytes(); len(errs) > 0 {
return code, body, errs := a.bytes()
if len(errs) > 0 {
return code, body, errs
} }
// TODO: This should only be done once
if a.jsonDecoder == nil { if a.jsonDecoder == nil {
a.jsonDecoder = json.Unmarshal a.jsonDecoder = json.Unmarshal
} }
@ -819,7 +823,7 @@ func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
errs = append(errs, err) errs = append(errs, err)
} }
return return code, body, errs
} }
func (a *Agent) release() { func (a *Agent) release() {
@ -855,6 +859,7 @@ func (a *Agent) reset() {
a.formFiles = a.formFiles[:0] a.formFiles = a.formFiles[:0]
} }
//nolint:gochecknoglobals // TODO: Do not use global vars here
var ( var (
clientPool sync.Pool clientPool sync.Pool
agentPool = sync.Pool{ agentPool = sync.Pool{
@ -877,7 +882,11 @@ func AcquireClient() *Client {
if v == nil { if v == nil {
return &Client{} return &Client{}
} }
return v.(*Client) c, ok := v.(*Client)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Client"))
}
return c
} }
// ReleaseClient returns c acquired via AcquireClient to client pool. // ReleaseClient returns c acquired via AcquireClient to client pool.
@ -899,7 +908,11 @@ func ReleaseClient(c *Client) {
// no longer needed. This allows Agent recycling, reduces GC pressure // no longer needed. This allows Agent recycling, reduces GC pressure
// and usually improves performance. // and usually improves performance.
func AcquireAgent() *Agent { func AcquireAgent() *Agent {
return agentPool.Get().(*Agent) a, ok := agentPool.Get().(*Agent)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Agent"))
}
return a
} }
// ReleaseAgent returns a acquired via AcquireAgent to Agent pool. // ReleaseAgent returns a acquired via AcquireAgent to Agent pool.
@ -922,7 +935,11 @@ func AcquireResponse() *Response {
if v == nil { if v == nil {
return &Response{} return &Response{}
} }
return v.(*Response) r, ok := v.(*Response)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Response"))
}
return r
} }
// ReleaseResponse return resp acquired via AcquireResponse to response pool. // ReleaseResponse return resp acquired via AcquireResponse to response pool.
@ -945,7 +962,11 @@ func AcquireArgs() *Args {
if v == nil { if v == nil {
return &Args{} return &Args{}
} }
return v.(*Args) a, ok := v.(*Args)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Args"))
}
return a
} }
// ReleaseArgs returns the object acquired via AcquireArgs to the pool. // ReleaseArgs returns the object acquired via AcquireArgs to the pool.
@ -966,7 +987,11 @@ func AcquireFormFile() *FormFile {
if v == nil { if v == nil {
return &FormFile{} return &FormFile{}
} }
return v.(*FormFile) ff, ok := v.(*FormFile)
if !ok {
panic(fmt.Errorf("failed to type-assert to *FormFile"))
}
return ff
} }
// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool. // ReleaseFormFile returns the object acquired via AcquireFormFile to the pool.
@ -981,9 +1006,7 @@ func ReleaseFormFile(ff *FormFile) {
formFilePool.Put(ff) formFilePool.Put(ff)
} }
var ( const (
strHTTP = []byte("http")
strHTTPS = []byte("https")
defaultUserAgent = "fiber" defaultUserAgent = "fiber"
) )

View File

@ -1,3 +1,4 @@
//nolint:wrapcheck // We must not wrap errors in tests
package fiber package fiber
import ( import (
@ -19,6 +20,7 @@ import (
"github.com/gofiber/fiber/v2/internal/tlstest" "github.com/gofiber/fiber/v2/internal/tlstest"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp/fasthttputil" "github.com/valyala/fasthttp/fasthttputil"
) )
@ -295,8 +297,10 @@ func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) {
handler := func(c *Ctx) error { handler := func(c *Ctx) error {
c.Request().Header.VisitAll(func(key, value []byte) { c.Request().Header.VisitAll(func(key, value []byte) {
if k := string(key); k == "K1" || k == "K2" { if k := string(key); k == "K1" || k == "K2" {
_, _ = c.Write(key) _, err := c.Write(key)
_, _ = c.Write(value) utils.AssertEqual(t, nil, err)
_, err = c.Write(value)
utils.AssertEqual(t, nil, err)
} }
}) })
return nil return nil
@ -581,25 +585,26 @@ type readErrorConn struct {
net.Conn net.Conn
} }
func (r *readErrorConn) Read(p []byte) (int, error) { func (*readErrorConn) Read(_ []byte) (int, error) {
return 0, fmt.Errorf("error") return 0, fmt.Errorf("error")
} }
func (r *readErrorConn) Write(p []byte) (int, error) { func (*readErrorConn) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (r *readErrorConn) Close() error { func (*readErrorConn) Close() error {
return nil return nil
} }
func (r *readErrorConn) LocalAddr() net.Addr { func (*readErrorConn) LocalAddr() net.Addr {
return nil return nil
} }
func (r *readErrorConn) RemoteAddr() net.Addr { func (*readErrorConn) RemoteAddr() net.Addr {
return nil return nil
} }
func Test_Client_Agent_RetryIf(t *testing.T) { func Test_Client_Agent_RetryIf(t *testing.T) {
t.Parallel() t.Parallel()
@ -783,7 +788,10 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) {
buf := make([]byte, fh1.Size) buf := make([]byte, fh1.Size)
f, err := fh1.Open() f, err := fh1.Open()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer func() { _ = f.Close() }() defer func() {
err := f.Close()
utils.AssertEqual(t, nil, err)
}()
_, err = f.Read(buf) _, err = f.Read(buf)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "form file", string(buf)) utils.AssertEqual(t, "form file", string(buf))
@ -831,13 +839,16 @@ func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) {
basename := filepath.Base(filename) basename := filepath.Base(filename)
utils.AssertEqual(t, fh.Filename, basename) utils.AssertEqual(t, fh.Filename, basename)
b1, err := os.ReadFile(filename) b1, err := os.ReadFile(filename) //nolint:gosec // We're in a test so reading user-provided files by name is fine
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
b2 := make([]byte, fh.Size) b2 := make([]byte, fh.Size)
f, err := fh.Open() f, err := fh.Open()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer func() { _ = f.Close() }() defer func() {
err := f.Close()
utils.AssertEqual(t, nil, err)
}()
_, err = f.Read(b2) _, err = f.Read(b2)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, b1, b2) utils.AssertEqual(t, b1, b2)
@ -962,6 +973,7 @@ func Test_Client_Agent_InsecureSkipVerify(t *testing.T) {
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
//nolint:gosec // We're in a test so using old ciphers is fine
serverTLSConf := &tls.Config{ serverTLSConf := &tls.Config{
Certificates: []tls.Certificate{cer}, Certificates: []tls.Certificate{cer},
} }
@ -1137,7 +1149,7 @@ func Test_Client_Agent_Struct(t *testing.T) {
defer ReleaseAgent(a) defer ReleaseAgent(a)
defer a.ConnectionClose() defer a.ConnectionClose()
request := a.Request() request := a.Request()
request.Header.SetMethod("GET") request.Header.SetMethod(MethodGet)
request.SetRequestURI("http://example.com") request.SetRequestURI("http://example.com")
err := a.Parse() err := a.Parse()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -1198,8 +1210,8 @@ type errorMultipartWriter struct {
count int count int
} }
func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } func (*errorMultipartWriter) Boundary() string { return "myBoundary" }
func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } func (*errorMultipartWriter) SetBoundary(_ string) error { return nil }
func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
if e.count == 0 { if e.count == 0 {
e.count++ e.count++
@ -1207,8 +1219,8 @@ func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
} }
return errorWriter{}, nil return errorWriter{}, nil
} }
func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } func (*errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } func (*errorMultipartWriter) Close() error { return errors.New("Close error") }
type errorWriter struct{} type errorWriter struct{}

View File

@ -53,6 +53,8 @@ type Colors struct {
} }
// DefaultColors Default color codes // DefaultColors Default color codes
//
//nolint:gochecknoglobals // Using a global var is fine here
var DefaultColors = Colors{ var DefaultColors = Colors{
Black: "\u001b[90m", Black: "\u001b[90m",
Red: "\u001b[91m", Red: "\u001b[91m",

188
ctx.go
View File

@ -27,10 +27,16 @@ import (
"github.com/gofiber/fiber/v2/internal/dictpool" "github.com/gofiber/fiber/v2/internal/dictpool"
"github.com/gofiber/fiber/v2/internal/schema" "github.com/gofiber/fiber/v2/internal/schema"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool" "github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
const (
schemeHTTP = "http"
schemeHTTPS = "https"
)
// maxParams defines the maximum number of parameters per route. // maxParams defines the maximum number of parameters per route.
const maxParams = 30 const maxParams = 30
@ -45,6 +51,7 @@ const (
// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx // userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
const userContextKey = "__local_user_context__" const userContextKey = "__local_user_context__"
//nolint:gochecknoglobals // TODO: Do not use global vars here
var ( var (
// decoderPoolMap helps to improve BodyParser's, QueryParser's and ReqHeaderParser's performance // decoderPoolMap helps to improve BodyParser's, QueryParser's and ReqHeaderParser's performance
decoderPoolMap = map[string]*sync.Pool{} decoderPoolMap = map[string]*sync.Pool{}
@ -52,6 +59,7 @@ var (
tags = []string{queryTag, bodyTag, reqHeaderTag, paramsTag} tags = []string{queryTag, bodyTag, reqHeaderTag, paramsTag}
) )
//nolint:gochecknoinits // init() is used to initialize a global map variable
func init() { func init() {
for _, tag := range tags { for _, tag := range tags {
decoderPoolMap[tag] = &sync.Pool{New: func() interface{} { decoderPoolMap[tag] = &sync.Pool{New: func() interface{} {
@ -100,9 +108,10 @@ type TLSHandler struct {
} }
// GetClientInfo Callback function to set CHI // GetClientInfo Callback function to set CHI
// TODO: Why is this a getter which sets stuff?
func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) { func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
t.clientHelloInfo = info t.clientHelloInfo = info
return nil, nil return nil, nil //nolint:nilnil // Not returning anything useful here is probably fine
} }
// Range data for c.Range // Range data for c.Range
@ -151,7 +160,10 @@ type ParserConfig struct {
// AcquireCtx retrieves a new Ctx from the pool. // AcquireCtx retrieves a new Ctx from the pool.
func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx { func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
c := app.pool.Get().(*Ctx) c, ok := app.pool.Get().(*Ctx)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Ctx"))
}
// Set app reference // Set app reference
c.app = app c.app = app
// Reset route and handler index // Reset route and handler index
@ -388,7 +400,6 @@ func (c *Ctx) BodyParser(out interface{}) error {
} else { } else {
data[k] = append(data[k], v) data[k] = append(data[k], v)
} }
}) })
return c.parseToStruct(bodyTag, out, data) return c.parseToStruct(bodyTag, out, data)
@ -401,7 +412,10 @@ func (c *Ctx) BodyParser(out interface{}) error {
return c.parseToStruct(bodyTag, out, data.Value) return c.parseToStruct(bodyTag, out, data.Value)
} }
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) { if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
return xml.Unmarshal(c.Body(), out) if err := xml.Unmarshal(c.Body(), out); err != nil {
return fmt.Errorf("failed to unmarshal: %w", err)
}
return nil
} }
// No suitable content type found // No suitable content type found
return ErrUnprocessableEntity return ErrUnprocessableEntity
@ -673,8 +687,11 @@ func (c *Ctx) Hostname() string {
// Port returns the remote port of the request. // Port returns the remote port of the request.
func (c *Ctx) Port() string { func (c *Ctx) Port() string {
port := c.fasthttp.RemoteAddr().(*net.TCPAddr).Port tcpaddr, ok := c.fasthttp.RemoteAddr().(*net.TCPAddr)
return strconv.Itoa(port) if !ok {
panic(fmt.Errorf("failed to type-assert to *net.TCPAddr"))
}
return strconv.Itoa(tcpaddr.Port)
} }
// IP returns the remote IP address of the request. // IP returns the remote IP address of the request.
@ -691,13 +708,16 @@ func (c *Ctx) IP() string {
// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear. // extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
// When IP validation is enabled, any invalid IPs will be omitted. // When IP validation is enabled, any invalid IPs will be omitted.
func (c *Ctx) extractIPsFromHeader(header string) []string { func (c *Ctx) extractIPsFromHeader(header string) []string {
// TODO: Reuse the c.extractIPFromHeader func somehow in here
headerValue := c.Get(header) headerValue := c.Get(header)
// We can't know how many IPs we will return, but we will try to guess with this constant division. // We can't know how many IPs we will return, but we will try to guess with this constant division.
// Counting ',' makes function slower for about 50ns in general case. // Counting ',' makes function slower for about 50ns in general case.
estimatedCount := len(headerValue) / 8 const maxEstimatedCount = 8
if estimatedCount > 8 { estimatedCount := len(headerValue) / maxEstimatedCount
estimatedCount = 8 // Avoid big allocation on big header if estimatedCount > maxEstimatedCount {
estimatedCount = maxEstimatedCount // Avoid big allocation on big header
} }
ipsFound := make([]string, 0, estimatedCount) ipsFound := make([]string, 0, estimatedCount)
@ -707,11 +727,10 @@ func (c *Ctx) extractIPsFromHeader(header string) []string {
iploop: iploop:
for { for {
v4 := false var v4, v6 bool
v6 := false
// Manually splitting string without allocating slice, working with parts directly // Manually splitting string without allocating slice, working with parts directly
i, j = j+1, j+2 i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
if j > len(headerValue) { if j > len(headerValue) {
break break
@ -758,9 +777,10 @@ func (c *Ctx) extractIPFromHeader(header string) string {
iploop: iploop:
for { for {
v4 := false var v4, v6 bool
v6 := false
i, j = j+1, j+2 // Manually splitting string without allocating slice, working with parts directly
i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
if j > len(headerValue) { if j > len(headerValue) {
break break
@ -793,14 +813,14 @@ func (c *Ctx) extractIPFromHeader(header string) string {
return c.fasthttp.RemoteIP().String() return c.fasthttp.RemoteIP().String()
} }
// default behaviour if IP validation is not enabled is just to return whatever value is // default behavior if IP validation is not enabled is just to return whatever value is
// in the proxy header. Even if it is empty or invalid // in the proxy header. Even if it is empty or invalid
return c.Get(c.app.config.ProxyHeader) return c.Get(c.app.config.ProxyHeader)
} }
// IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header. // IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header.
// When IP validation is enabled, only valid IPs are returned. // When IP validation is enabled, only valid IPs are returned.
func (c *Ctx) IPs() (ips []string) { func (c *Ctx) IPs() []string {
return c.extractIPsFromHeader(HeaderXForwardedFor) return c.extractIPsFromHeader(HeaderXForwardedFor)
} }
@ -839,7 +859,7 @@ func (c *Ctx) JSON(data interface{}) error {
func (c *Ctx) JSONP(data interface{}, callback ...string) error { func (c *Ctx) JSONP(data interface{}, callback ...string) error {
raw, err := json.Marshal(data) raw, err := json.Marshal(data)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to marshal: %w", err)
} }
var result, cb string var result, cb string
@ -877,11 +897,11 @@ func (c *Ctx) Links(link ...string) {
bb := bytebufferpool.Get() bb := bytebufferpool.Get()
for i := range link { for i := range link {
if i%2 == 0 { if i%2 == 0 {
_ = bb.WriteByte('<') _ = bb.WriteByte('<') //nolint:errcheck // This will never fail
_, _ = bb.WriteString(link[i]) _, _ = bb.WriteString(link[i]) //nolint:errcheck // This will never fail
_ = bb.WriteByte('>') _ = bb.WriteByte('>') //nolint:errcheck // This will never fail
} else { } else {
_, _ = bb.WriteString(`; rel="` + link[i] + `",`) _, _ = bb.WriteString(`; rel="` + link[i] + `",`) //nolint:errcheck // This will never fail
} }
} }
c.setCanonical(HeaderLink, utils.TrimRight(c.app.getString(bb.Bytes()), ',')) c.setCanonical(HeaderLink, utils.TrimRight(c.app.getString(bb.Bytes()), ','))
@ -890,7 +910,7 @@ func (c *Ctx) Links(link ...string) {
// Locals makes it possible to pass interface{} values under keys scoped to the request // Locals makes it possible to pass interface{} values under keys scoped to the request
// and therefore available to all following routes that match the request. // and therefore available to all following routes that match the request.
func (c *Ctx) Locals(key interface{}, value ...interface{}) (val interface{}) { func (c *Ctx) Locals(key interface{}, value ...interface{}) interface{} {
if len(value) == 0 { if len(value) == 0 {
return c.fasthttp.UserValue(key) return c.fasthttp.UserValue(key)
} }
@ -933,9 +953,10 @@ func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo {
} }
// Next executes the next method in the stack that matches the current route. // Next executes the next method in the stack that matches the current route.
func (c *Ctx) Next() (err error) { func (c *Ctx) Next() error {
// Increment handler index // Increment handler index
c.indexHandler++ c.indexHandler++
var err error
// Did we executed all route handlers? // Did we executed all route handlers?
if c.indexHandler < len(c.route.Handlers) { if c.indexHandler < len(c.route.Handlers) {
// Continue route stack // Continue route stack
@ -947,7 +968,7 @@ func (c *Ctx) Next() (err error) {
return err return err
} }
// RestartRouting instead of going to the next handler. This may be usefull after // RestartRouting instead of going to the next handler. This may be useful after
// changing the request path. Note that handlers might be executed again. // changing the request path. Note that handlers might be executed again.
func (c *Ctx) RestartRouting() error { func (c *Ctx) RestartRouting() error {
c.indexRoute = -1 c.indexRoute = -1
@ -1017,9 +1038,8 @@ func (c *Ctx) ParamsInt(key string, defaultValue ...int) (int, error) {
if err != nil { if err != nil {
if len(defaultValue) > 0 { if len(defaultValue) > 0 {
return defaultValue[0], nil return defaultValue[0], nil
} else {
return 0, err
} }
return 0, fmt.Errorf("failed to convert: %w", err)
} }
return value, nil return value, nil
@ -1044,15 +1064,16 @@ func (c *Ctx) Path(override ...string) string {
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy. // Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
func (c *Ctx) Protocol() string { func (c *Ctx) Protocol() string {
if c.fasthttp.IsTLS() { if c.fasthttp.IsTLS() {
return "https" return schemeHTTPS
} }
if !c.IsProxyTrusted() { if !c.IsProxyTrusted() {
return "http" return schemeHTTP
} }
scheme := "http" scheme := schemeHTTP
const lenXHeaderName = 12
c.fasthttp.Request.Header.VisitAll(func(key, val []byte) { c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
if len(key) < 12 { if len(key) < lenXHeaderName {
return // Neither "X-Forwarded-" nor "X-Url-Scheme" return // Neither "X-Forwarded-" nor "X-Url-Scheme"
} }
switch { switch {
@ -1067,7 +1088,7 @@ func (c *Ctx) Protocol() string {
scheme = v scheme = v
} }
} else if bytes.Equal(key, []byte(HeaderXForwardedSsl)) && bytes.Equal(val, []byte("on")) { } else if bytes.Equal(key, []byte(HeaderXForwardedSsl)) && bytes.Equal(val, []byte("on")) {
scheme = "https" scheme = schemeHTTPS
} }
case bytes.Equal(key, []byte(HeaderXUrlScheme)): case bytes.Equal(key, []byte(HeaderXUrlScheme)):
@ -1100,9 +1121,8 @@ func (c *Ctx) QueryInt(key string, defaultValue ...int) int {
if err != nil { if err != nil {
if len(defaultValue) > 0 { if len(defaultValue) > 0 {
return defaultValue[0] return defaultValue[0]
} else {
return 0
} }
return 0
} }
return value return value
@ -1133,7 +1153,6 @@ func (c *Ctx) QueryParser(out interface{}) error {
} else { } else {
data[k] = append(data[k], v) data[k] = append(data[k], v)
} }
}) })
if err != nil { if err != nil {
@ -1150,10 +1169,9 @@ func parseParamSquareBrackets(k string) (string, error) {
kbytes := []byte(k) kbytes := []byte(k)
for i, b := range kbytes { for i, b := range kbytes {
if b == '[' && kbytes[i+1] != ']' { if b == '[' && kbytes[i+1] != ']' {
if err := bb.WriteByte('.'); err != nil { if err := bb.WriteByte('.'); err != nil {
return "", err return "", fmt.Errorf("failed to write: %w", err)
} }
} }
@ -1162,7 +1180,7 @@ func parseParamSquareBrackets(k string) (string, error) {
} }
if err := bb.WriteByte(b); err != nil { if err := bb.WriteByte(b); err != nil {
return "", err return "", fmt.Errorf("failed to write: %w", err)
} }
} }
@ -1184,21 +1202,27 @@ func (c *Ctx) ReqHeaderParser(out interface{}) error {
} else { } else {
data[k] = append(data[k], v) data[k] = append(data[k], v)
} }
}) })
return c.parseToStruct(reqHeaderTag, out, data) return c.parseToStruct(reqHeaderTag, out, data)
} }
func (c *Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error { func (*Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error {
// Get decoder from pool // Get decoder from pool
schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder) schemaDecoder, ok := decoderPoolMap[aliasTag].Get().(*schema.Decoder)
if !ok {
panic(fmt.Errorf("failed to type-assert to *schema.Decoder"))
}
defer decoderPoolMap[aliasTag].Put(schemaDecoder) defer decoderPoolMap[aliasTag].Put(schemaDecoder)
// Set alias tag // Set alias tag
schemaDecoder.SetAliasTag(aliasTag) schemaDecoder.SetAliasTag(aliasTag)
return schemaDecoder.Decode(out, data) if err := schemaDecoder.Decode(out, data); err != nil {
return fmt.Errorf("failed to decode: %w", err)
}
return nil
} }
func equalFieldType(out interface{}, kind reflect.Kind, key string) bool { func equalFieldType(out interface{}, kind reflect.Kind, key string) bool {
@ -1248,24 +1272,23 @@ var (
) )
// Range returns a struct containing the type and a slice of ranges. // Range returns a struct containing the type and a slice of ranges.
func (c *Ctx) Range(size int) (rangeData Range, err error) { func (c *Ctx) Range(size int) (Range, error) {
var rangeData Range
rangeStr := c.Get(HeaderRange) rangeStr := c.Get(HeaderRange)
if rangeStr == "" || !strings.Contains(rangeStr, "=") { if rangeStr == "" || !strings.Contains(rangeStr, "=") {
err = ErrRangeMalformed return rangeData, ErrRangeMalformed
return
} }
data := strings.Split(rangeStr, "=") data := strings.Split(rangeStr, "=")
if len(data) != 2 { const expectedDataParts = 2
err = ErrRangeMalformed if len(data) != expectedDataParts {
return return rangeData, ErrRangeMalformed
} }
rangeData.Type = data[0] rangeData.Type = data[0]
arr := strings.Split(data[1], ",") arr := strings.Split(data[1], ",")
for i := 0; i < len(arr); i++ { for i := 0; i < len(arr); i++ {
item := strings.Split(arr[i], "-") item := strings.Split(arr[i], "-")
if len(item) == 1 { if len(item) == 1 {
err = ErrRangeMalformed return rangeData, ErrRangeMalformed
return
} }
start, startErr := strconv.Atoi(item[0]) start, startErr := strconv.Atoi(item[0])
end, endErr := strconv.Atoi(item[1]) end, endErr := strconv.Atoi(item[1])
@ -1290,11 +1313,10 @@ func (c *Ctx) Range(size int) (rangeData Range, err error) {
}) })
} }
if len(rangeData.Ranges) < 1 { if len(rangeData.Ranges) < 1 {
err = ErrRangeUnsatisfiable return rangeData, ErrRangeUnsatisfiable
return
} }
return return rangeData, nil
} }
// Redirect to the URL derived from the specified path, with specified status. // Redirect to the URL derived from the specified path, with specified status.
@ -1330,7 +1352,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
if !segment.IsParam { if !segment.IsParam {
_, err := buf.WriteString(segment.Const) _, err := buf.WriteString(segment.Const)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to write string: %w", err)
} }
continue continue
} }
@ -1341,7 +1363,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
if isSame || isGreedy { if isSame || isGreedy {
_, err := buf.WriteString(utils.ToString(val)) _, err := buf.WriteString(utils.ToString(val))
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to write string: %w", err)
} }
} }
} }
@ -1373,10 +1395,10 @@ func (c *Ctx) RedirectToRoute(routeName string, params Map, status ...int) error
i := 1 i := 1
for k, v := range queries { for k, v := range queries {
_, _ = queryText.WriteString(k + "=" + v) _, _ = queryText.WriteString(k + "=" + v) //nolint:errcheck // This will never fail
if i != len(queries) { if i != len(queries) {
_, _ = queryText.WriteString("&") _, _ = queryText.WriteString("&") //nolint:errcheck // This will never fail
} }
i++ i++
} }
@ -1399,7 +1421,6 @@ func (c *Ctx) RedirectBack(fallback string, status ...int) error {
// Render a template with data and sends a text/html response. // Render a template with data and sends a text/html response.
// We support the following engines: html, amber, handlebars, mustache, pug // We support the following engines: html, amber, handlebars, mustache, pug
func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error { func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
var err error
// Get new buffer from pool // Get new buffer from pool
buf := bytebufferpool.Get() buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf) defer bytebufferpool.Put(buf)
@ -1421,7 +1442,7 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
// Render template from Views // Render template from Views
if app.config.Views != nil { if app.config.Views != nil {
if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil { if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil {
return err return fmt.Errorf("failed to render: %w", err)
} }
rendered = true rendered = true
@ -1433,17 +1454,18 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
if !rendered { if !rendered {
// Render raw template using 'name' as filepath if no engine is set // Render raw template using 'name' as filepath if no engine is set
var tmpl *template.Template var tmpl *template.Template
if _, err = readContent(buf, name); err != nil { if _, err := readContent(buf, name); err != nil {
return err return err
} }
// Parse template // Parse template
if tmpl, err = template.New("").Parse(c.app.getString(buf.Bytes())); err != nil { tmpl, err := template.New("").Parse(c.app.getString(buf.Bytes()))
return err if err != nil {
return fmt.Errorf("failed to parse: %w", err)
} }
buf.Reset() buf.Reset()
// Render template // Render template
if err = tmpl.Execute(buf, bind); err != nil { if err := tmpl.Execute(buf, bind); err != nil {
return err return fmt.Errorf("failed to execute: %w", err)
} }
} }
@ -1451,8 +1473,8 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
c.fasthttp.Response.Header.SetContentType(MIMETextHTMLCharsetUTF8) c.fasthttp.Response.Header.SetContentType(MIMETextHTMLCharsetUTF8)
// Set rendered template to body // Set rendered template to body
c.fasthttp.Response.SetBody(buf.Bytes()) c.fasthttp.Response.SetBody(buf.Bytes())
// Return err if exist
return err return nil
} }
func (c *Ctx) renderExtensions(bind interface{}) { func (c *Ctx) renderExtensions(bind interface{}) {
@ -1501,28 +1523,32 @@ func (c *Ctx) Route() *Route {
} }
// SaveFile saves any multipart file to disk. // SaveFile saves any multipart file to disk.
func (c *Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error { func (*Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error {
return fasthttp.SaveMultipartFile(fileheader, path) return fasthttp.SaveMultipartFile(fileheader, path)
} }
// SaveFileToStorage saves any multipart file to an external storage system. // SaveFileToStorage saves any multipart file to an external storage system.
func (c *Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error { func (*Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
file, err := fileheader.Open() file, err := fileheader.Open()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to open: %w", err)
} }
content, err := io.ReadAll(file) content, err := io.ReadAll(file)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to read: %w", err)
} }
return storage.Set(path, content, 0) if err := storage.Set(path, content, 0); err != nil {
return fmt.Errorf("failed to store: %w", err)
}
return nil
} }
// Secure returns whether a secure connection was established. // Secure returns whether a secure connection was established.
func (c *Ctx) Secure() bool { func (c *Ctx) Secure() bool {
return c.Protocol() == "https" return c.Protocol() == schemeHTTPS
} }
// Send sets the HTTP response body without copying it. // Send sets the HTTP response body without copying it.
@ -1533,6 +1559,7 @@ func (c *Ctx) Send(body []byte) error {
return nil return nil
} }
//nolint:gochecknoglobals // TODO: Do not use global vars here
var ( var (
sendFileOnce sync.Once sendFileOnce sync.Once
sendFileFS *fasthttp.FS sendFileFS *fasthttp.FS
@ -1548,6 +1575,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
// https://github.com/valyala/fasthttp/blob/c7576cc10cabfc9c993317a2d3f8355497bea156/fs.go#L129-L134 // https://github.com/valyala/fasthttp/blob/c7576cc10cabfc9c993317a2d3f8355497bea156/fs.go#L129-L134
sendFileOnce.Do(func() { sendFileOnce.Do(func() {
const cacheDuration = 10 * time.Second
sendFileFS = &fasthttp.FS{ sendFileFS = &fasthttp.FS{
Root: "", Root: "",
AllowEmptyRoot: true, AllowEmptyRoot: true,
@ -1555,7 +1583,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
AcceptByteRange: true, AcceptByteRange: true,
Compress: true, Compress: true,
CompressedFileSuffix: c.app.config.CompressedFileSuffix, CompressedFileSuffix: c.app.config.CompressedFileSuffix,
CacheDuration: 10 * time.Second, CacheDuration: cacheDuration,
IndexNames: []string{"index.html"}, IndexNames: []string{"index.html"},
PathNotFound: func(ctx *fasthttp.RequestCtx) { PathNotFound: func(ctx *fasthttp.RequestCtx) {
ctx.Response.SetStatusCode(StatusNotFound) ctx.Response.SetStatusCode(StatusNotFound)
@ -1579,7 +1607,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
var err error var err error
file = filepath.FromSlash(file) file = filepath.FromSlash(file)
if file, err = filepath.Abs(file); err != nil { if file, err = filepath.Abs(file); err != nil {
return err return fmt.Errorf("failed to determine abs file path: %w", err)
} }
if hasTrailingSlash { if hasTrailingSlash {
file += "/" file += "/"
@ -1644,11 +1672,11 @@ func (c *Ctx) SendStream(stream io.Reader, size ...int) error {
} }
// Set sets the response's HTTP header field to the specified key, value. // Set sets the response's HTTP header field to the specified key, value.
func (c *Ctx) Set(key string, val string) { func (c *Ctx) Set(key, val string) {
c.fasthttp.Response.Header.Set(key, val) c.fasthttp.Response.Header.Set(key, val)
} }
func (c *Ctx) setCanonical(key string, val string) { func (c *Ctx) setCanonical(key, val string) {
c.fasthttp.Response.Header.SetCanonical(c.app.getBytes(key), c.app.getBytes(val)) c.fasthttp.Response.Header.SetCanonical(c.app.getBytes(key), c.app.getBytes(val))
} }
@ -1719,6 +1747,7 @@ func (c *Ctx) Write(p []byte) (int, error) {
// Writef appends f & a into response body writer. // Writef appends f & a into response body writer.
func (c *Ctx) Writef(f string, a ...interface{}) (int, error) { func (c *Ctx) Writef(f string, a ...interface{}) (int, error) {
//nolint:wrapcheck // This must not be wrapped
return fmt.Fprintf(c.fasthttp.Response.BodyWriter(), f, a...) return fmt.Fprintf(c.fasthttp.Response.BodyWriter(), f, a...)
} }
@ -1760,8 +1789,9 @@ func (c *Ctx) configDependentPaths() {
// Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed, // Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed,
// since the first three characters area select a list of routes // since the first three characters area select a list of routes
c.treePath = c.treePath[0:0] c.treePath = c.treePath[0:0]
if len(c.detectionPath) >= 3 { const maxDetectionPaths = 3
c.treePath = c.detectionPath[:3] if len(c.detectionPath) >= maxDetectionPaths {
c.treePath = c.detectionPath[:maxDetectionPaths]
} }
} }
@ -1786,7 +1816,7 @@ func (c *Ctx) IsProxyTrusted() bool {
} }
// IsLocalHost will return true if address is a localhost address. // IsLocalHost will return true if address is a localhost address.
func (c *Ctx) isLocalHost(address string) bool { func (*Ctx) isLocalHost(address string) bool {
localHosts := []string{"127.0.0.1", "0.0.0.0", "::1"} localHosts := []string{"127.0.0.1", "0.0.0.0", "::1"}
for _, h := range localHosts { for _, h := range localHosts {
if strings.Contains(address, h) { if strings.Contains(address, h) {

View File

@ -2,6 +2,7 @@
// 🤖 Github Repository: https://github.com/gofiber/fiber // 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io // 📌 API Documentation: https://docs.gofiber.io
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package fiber package fiber
import ( import (
@ -28,6 +29,7 @@ import (
"github.com/gofiber/fiber/v2/internal/storage/memory" "github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool" "github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -361,8 +363,10 @@ func Test_Ctx_BodyParser(t *testing.T) {
{ {
var gzipJSON bytes.Buffer var gzipJSON bytes.Buffer
w := gzip.NewWriter(&gzipJSON) w := gzip.NewWriter(&gzipJSON)
_, _ = w.Write([]byte(`{"name":"john"}`)) _, err := w.Write([]byte(`{"name":"john"}`))
_ = w.Close() utils.AssertEqual(t, nil, err)
err = w.Close()
utils.AssertEqual(t, nil, err)
c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().Header.SetContentType(MIMEApplicationJSON)
c.Request().Header.Set(HeaderContentEncoding, "gzip") c.Request().Header.Set(HeaderContentEncoding, "gzip")
@ -431,9 +435,7 @@ func Test_Ctx_ParamParser(t *testing.T) {
UserID uint `params:"userId"` UserID uint `params:"userId"`
RoleID uint `params:"roleId"` RoleID uint `params:"roleId"`
} }
var ( d := new(Demo)
d = new(Demo)
)
if err := ctx.ParamsParser(d); err != nil { if err := ctx.ParamsParser(d); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -519,7 +521,7 @@ func Benchmark_Ctx_BodyParser_JSON(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_ = c.BodyParser(d) _ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
} }
utils.AssertEqual(b, nil, c.BodyParser(d)) utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name) utils.AssertEqual(b, "john", d.Name)
@ -543,7 +545,7 @@ func Benchmark_Ctx_BodyParser_XML(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_ = c.BodyParser(d) _ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
} }
utils.AssertEqual(b, nil, c.BodyParser(d)) utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name) utils.AssertEqual(b, "john", d.Name)
@ -567,7 +569,7 @@ func Benchmark_Ctx_BodyParser_Form(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_ = c.BodyParser(d) _ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
} }
utils.AssertEqual(b, nil, c.BodyParser(d)) utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name) utils.AssertEqual(b, "john", d.Name)
@ -592,7 +594,7 @@ func Benchmark_Ctx_BodyParser_MultipartForm(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_ = c.BodyParser(d) _ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
} }
utils.AssertEqual(b, nil, c.BodyParser(d)) utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name) utils.AssertEqual(b, "john", d.Name)
@ -879,12 +881,13 @@ func Test_Ctx_FormFile(t *testing.T) {
f, err := fh.Open() f, err := fh.Open()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer func() {
utils.AssertEqual(t, nil, f.Close())
}()
b := new(bytes.Buffer) b := new(bytes.Buffer)
_, err = io.Copy(b, f) _, err = io.Copy(b, f)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
f.Close()
utils.AssertEqual(t, "hello world", b.String()) utils.AssertEqual(t, "hello world", b.String())
return nil return nil
}) })
@ -897,8 +900,7 @@ func Test_Ctx_FormFile(t *testing.T) {
_, err = ioWriter.Write([]byte("hello world")) _, err = ioWriter.Write([]byte("hello world"))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, writer.Close())
writer.Close()
req := httptest.NewRequest(MethodPost, "/test", body) req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set(HeaderContentType, writer.FormDataContentType()) req.Header.Set(HeaderContentType, writer.FormDataContentType())
@ -921,10 +923,9 @@ func Test_Ctx_FormValue(t *testing.T) {
body := &bytes.Buffer{} body := &bytes.Buffer{}
writer := multipart.NewWriter(body) writer := multipart.NewWriter(body)
utils.AssertEqual(t, nil, writer.WriteField("name", "john")) utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
utils.AssertEqual(t, nil, writer.Close())
writer.Close()
req := httptest.NewRequest(MethodPost, "/test", body) req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set("Content-Type", fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary())) req.Header.Set("Content-Type", fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))
@ -1240,7 +1241,7 @@ func Test_Ctx_IP(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{}) c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
// default behaviour will return the remote IP from the stack // default behavior will return the remote IP from the stack
utils.AssertEqual(t, "0.0.0.0", c.IP()) utils.AssertEqual(t, "0.0.0.0", c.IP())
// X-Forwarded-For is set, but it is ignored because proxyHeader is not set // X-Forwarded-For is set, but it is ignored because proxyHeader is not set
@ -1252,7 +1253,7 @@ func Test_Ctx_IP(t *testing.T) {
func Test_Ctx_IP_ProxyHeader(t *testing.T) { func Test_Ctx_IP_ProxyHeader(t *testing.T) {
t.Parallel() t.Parallel()
// make sure that the same behaviour exists for different proxy header names // make sure that the same behavior exists for different proxy header names
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor} proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
for _, proxyHeaderName := range proxyHeaderNames { for _, proxyHeaderName := range proxyHeaderNames {
@ -1286,7 +1287,7 @@ func Test_Ctx_IP_ProxyHeader(t *testing.T) {
func Test_Ctx_IP_ProxyHeader_With_IP_Validation(t *testing.T) { func Test_Ctx_IP_ProxyHeader_With_IP_Validation(t *testing.T) {
t.Parallel() t.Parallel()
// make sure that the same behaviour exists for different proxy header names // make sure that the same behavior exists for different proxy header names
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor} proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
for _, proxyHeaderName := range proxyHeaderNames { for _, proxyHeaderName := range proxyHeaderNames {
@ -1625,35 +1626,43 @@ func Test_Ctx_ClientHelloInfo(t *testing.T) {
}) })
// Test without TLS handler // Test without TLS handler
resp, _ := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil)) resp, err := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ := io.ReadAll(resp.Body) utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body) utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body)
// Test with TLS Handler // Test with TLS Handler
const ( const (
PSSWithSHA256 = 0x0804 pssWithSHA256 = 0x0804
VersionTLS13 = 0x0304 versionTLS13 = 0x0304
) )
app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{ app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{
ServerName: "example.golang", ServerName: "example.golang",
SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256}, SignatureSchemes: []tls.SignatureScheme{pssWithSHA256},
SupportedVersions: []uint16{VersionTLS13}, SupportedVersions: []uint16{versionTLS13},
}} }}
// Test ServerName // Test ServerName
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ = io.ReadAll(resp.Body) utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("example.golang"), body) utils.AssertEqual(t, []byte("example.golang"), body)
// Test SignatureSchemes // Test SignatureSchemes
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
body, _ = io.ReadAll(resp.Body) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "["+strconv.Itoa(PSSWithSHA256)+"]", string(body)) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "["+strconv.Itoa(pssWithSHA256)+"]", string(body))
// Test SupportedVersions // Test SupportedVersions
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
body, _ = io.ReadAll(resp.Body) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "["+strconv.Itoa(VersionTLS13)+"]", string(body)) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "["+strconv.Itoa(versionTLS13)+"]", string(body))
} }
// go test -run Test_Ctx_InvalidMethod // go test -run Test_Ctx_InvalidMethod
@ -1688,10 +1697,9 @@ func Test_Ctx_MultipartForm(t *testing.T) {
body := &bytes.Buffer{} body := &bytes.Buffer{}
writer := multipart.NewWriter(body) writer := multipart.NewWriter(body)
utils.AssertEqual(t, nil, writer.WriteField("name", "john")) utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
utils.AssertEqual(t, nil, writer.Close())
writer.Close()
req := httptest.NewRequest(MethodPost, "/test", body) req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set(HeaderContentType, fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary())) req.Header.Set(HeaderContentType, fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
req.Header.Set(HeaderContentLength, strconv.Itoa(len(body.Bytes()))) req.Header.Set(HeaderContentLength, strconv.Itoa(len(body.Bytes())))
@ -1706,8 +1714,8 @@ func Benchmark_Ctx_MultipartForm(b *testing.B) {
app := New() app := New()
app.Post("/", func(c *Ctx) error { app.Post("/", func(c *Ctx) error {
_, _ = c.MultipartForm() _, err := c.MultipartForm()
return nil return err
}) })
c := &fasthttp.RequestCtx{} c := &fasthttp.RequestCtx{}
@ -1889,11 +1897,16 @@ func Benchmark_Ctx_AllParams(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res = c.AllParams() res = c.AllParams()
} }
utils.AssertEqual(b, map[string]string{"param1": "john", utils.AssertEqual(
b,
map[string]string{
"param1": "john",
"param2": "doe", "param2": "doe",
"param3": "is", "param3": "is",
"param4": "awesome"}, "param4": "awesome",
res) },
res,
)
} }
// go test -v -run=^$ -bench=Benchmark_Ctx_ParamsParse -benchmem -count=4 // go test -v -run=^$ -bench=Benchmark_Ctx_ParamsParse -benchmem -count=4
@ -1964,31 +1977,31 @@ func Test_Ctx_Protocol(t *testing.T) {
c := app.AcquireCtx(freq) c := app.AcquireCtx(freq)
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https") c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https") c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProto, "https, http") c.Request().Header.Set(HeaderXForwardedProto, "https, http")
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https, http") c.Request().Header.Set(HeaderXForwardedProtocol, "https, http")
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on") c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https") c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
} }
// go test -v -run=^$ -bench=Benchmark_Ctx_Protocol -benchmem -count=4 // go test -v -run=^$ -bench=Benchmark_Ctx_Protocol -benchmem -count=4
@ -2002,7 +2015,7 @@ func Benchmark_Ctx_Protocol(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res = c.Protocol() res = c.Protocol()
} }
utils.AssertEqual(b, "http", res) utils.AssertEqual(b, schemeHTTP, res)
} }
// go test -run Test_Ctx_Protocol_TrustedProxy // go test -run Test_Ctx_Protocol_TrustedProxy
@ -2012,23 +2025,23 @@ func Test_Ctx_Protocol_TrustedProxy(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{}) c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https") c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https") c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on") c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https") c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
} }
// go test -run Test_Ctx_Protocol_TrustedProxyRange // go test -run Test_Ctx_Protocol_TrustedProxyRange
@ -2038,23 +2051,23 @@ func Test_Ctx_Protocol_TrustedProxyRange(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{}) c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https") c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https") c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on") c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https") c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, "https", c.Protocol()) utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
} }
// go test -run Test_Ctx_Protocol_UntrustedProxyRange // go test -run Test_Ctx_Protocol_UntrustedProxyRange
@ -2064,23 +2077,23 @@ func Test_Ctx_Protocol_UntrustedProxyRange(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{}) c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https") c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https") c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on") c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https") c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
} }
// go test -run Test_Ctx_Protocol_UnTrustedProxy // go test -run Test_Ctx_Protocol_UnTrustedProxy
@ -2090,23 +2103,23 @@ func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{}) c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https") c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https") c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on") c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https") c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset() c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol()) utils.AssertEqual(t, schemeHTTP, c.Protocol())
} }
// go test -run Test_Ctx_Query // go test -run Test_Ctx_Query
@ -2224,7 +2237,12 @@ func Test_Ctx_SaveFile(t *testing.T) {
tempFile, err := os.CreateTemp(os.TempDir(), "test-") tempFile, err := os.CreateTemp(os.TempDir(), "test-")
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer os.Remove(tempFile.Name()) defer func(file *os.File) {
err := file.Close()
utils.AssertEqual(t, nil, err)
err = os.Remove(file.Name())
utils.AssertEqual(t, nil, err)
}(tempFile)
err = c.SaveFile(fh, tempFile.Name()) err = c.SaveFile(fh, tempFile.Name())
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -2242,7 +2260,7 @@ func Test_Ctx_SaveFile(t *testing.T) {
_, err = ioWriter.Write([]byte("hello world")) _, err = ioWriter.Write([]byte("hello world"))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
writer.Close() utils.AssertEqual(t, nil, writer.Close())
req := httptest.NewRequest(MethodPost, "/test", body) req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set("Content-Type", writer.FormDataContentType()) req.Header.Set("Content-Type", writer.FormDataContentType())
@ -2284,7 +2302,7 @@ func Test_Ctx_SaveFileToStorage(t *testing.T) {
_, err = ioWriter.Write([]byte("hello world")) _, err = ioWriter.Write([]byte("hello world"))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
writer.Close() utils.AssertEqual(t, nil, writer.Close())
req := httptest.NewRequest(MethodPost, "/test", body) req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set("Content-Type", writer.FormDataContentType()) req.Header.Set("Content-Type", writer.FormDataContentType())
@ -2370,7 +2388,9 @@ func Test_Ctx_Download(t *testing.T) {
f, err := os.Open("./ctx.go") f, err := os.Open("./ctx.go")
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer f.Close() defer func() {
utils.AssertEqual(t, nil, f.Close())
}()
expect, err := io.ReadAll(f) expect, err := io.ReadAll(f)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -2389,7 +2409,9 @@ func Test_Ctx_SendFile(t *testing.T) {
// fetch file content // fetch file content
f, err := os.Open("./ctx.go") f, err := os.Open("./ctx.go")
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer f.Close() defer func() {
utils.AssertEqual(t, nil, f.Close())
}()
expectFileContent, err := io.ReadAll(f) expectFileContent, err := io.ReadAll(f)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
// fetch file info for the not modified test case // fetch file info for the not modified test case
@ -2435,7 +2457,7 @@ func Test_Ctx_SendFile_404(t *testing.T) {
return err return err
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, StatusNotFound, resp.StatusCode) utils.AssertEqual(t, StatusNotFound, resp.StatusCode)
} }
@ -2473,11 +2495,11 @@ func Test_Ctx_SendFile_Immutable(t *testing.T) {
for _, endpoint := range endpointsForTest { for _, endpoint := range endpointsForTest {
t.Run(endpoint, func(t *testing.T) { t.Run(endpoint, func(t *testing.T) {
// 1st try // 1st try
resp, err := app.Test(httptest.NewRequest("GET", endpoint, nil)) resp, err := app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, StatusOK, resp.StatusCode) utils.AssertEqual(t, StatusOK, resp.StatusCode)
// 2nd try // 2nd try
resp, err = app.Test(httptest.NewRequest("GET", endpoint, nil)) resp, err = app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, StatusOK, resp.StatusCode) utils.AssertEqual(t, StatusOK, resp.StatusCode)
}) })
@ -2495,9 +2517,9 @@ func Test_Ctx_SendFile_RestoreOriginalURL(t *testing.T) {
return err return err
}) })
_, err1 := app.Test(httptest.NewRequest("GET", "/?test=true", nil)) _, err1 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
// second request required to confirm with zero allocation // second request required to confirm with zero allocation
_, err2 := app.Test(httptest.NewRequest("GET", "/?test=true", nil)) _, err2 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
utils.AssertEqual(t, nil, err1) utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, nil, err2) utils.AssertEqual(t, nil, err2)
@ -2893,12 +2915,12 @@ func Test_Ctx_Render(t *testing.T) {
err := c.Render("./.github/testdata/index.tmpl", Map{ err := c.Render("./.github/testdata/index.tmpl", Map{
"Title": "Hello, World!", "Title": "Hello, World!",
}) })
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get() buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite") _, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf) defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body())) utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
err = c.Render("./.github/testdata/template-non-exists.html", nil) err = c.Render("./.github/testdata/template-non-exists.html", nil)
@ -2918,12 +2940,12 @@ func Test_Ctx_RenderWithoutLocals(t *testing.T) {
c.Locals("Title", "Hello, World!") c.Locals("Title", "Hello, World!")
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
err := c.Render("./.github/testdata/index.tmpl", Map{}) err := c.Render("./.github/testdata/index.tmpl", Map{})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get() buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite") _, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf) defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1><no value></h1>", string(c.Response().Body())) utils.AssertEqual(t, "<h1><no value></h1>", string(c.Response().Body()))
} }
@ -2937,14 +2959,13 @@ func Test_Ctx_RenderWithLocals(t *testing.T) {
c.Locals("Title", "Hello, World!") c.Locals("Title", "Hello, World!")
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
err := c.Render("./.github/testdata/index.tmpl", Map{}) err := c.Render("./.github/testdata/index.tmpl", Map{})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get() buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite") _, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf) defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body())) utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
} }
func Test_Ctx_RenderWithBind(t *testing.T) { func Test_Ctx_RenderWithBind(t *testing.T) {
@ -2959,14 +2980,13 @@ func Test_Ctx_RenderWithBind(t *testing.T) {
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
err = c.Render("./.github/testdata/index.tmpl", Map{}) err = c.Render("./.github/testdata/index.tmpl", Map{})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get() buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite") _, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf) defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body())) utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
} }
func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) { func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
@ -2982,12 +3002,12 @@ func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
err = c.Render("./.github/testdata/index.tmpl", Map{ err = c.Render("./.github/testdata/index.tmpl", Map{
"Title": "Hello from Fiber!", "Title": "Hello from Fiber!",
}) })
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get() buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite") _, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf) defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello from Fiber!</h1>", string(c.Response().Body())) utils.AssertEqual(t, "<h1>Hello from Fiber!</h1>", string(c.Response().Body()))
} }
@ -3005,13 +3025,12 @@ func Test_Ctx_RenderWithBindLocals(t *testing.T) {
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
c.Locals("Summary", "Test") c.Locals("Summary", "Test")
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
err = c.Render("./.github/testdata/template.tmpl", Map{}) err = c.Render("./.github/testdata/template.tmpl", Map{})
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
} }
func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) { func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
@ -3027,11 +3046,12 @@ func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
c.Locals("Title", "This is a test.") c.Locals("Title", "This is a test.")
defer app.ReleaseCtx(c) defer app.ReleaseCtx(c)
err = c.Render("index.tmpl", Map{ err = c.Render("index.tmpl", Map{
"Title": "Hello, World!", "Title": "Hello, World!",
}) })
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body())) utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
} }
@ -3060,8 +3080,8 @@ func Benchmark_Ctx_RenderWithLocalsAndBinding(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
err = c.Render("template.tmpl", Map{}) err = c.Render("template.tmpl", Map{})
} }
utils.AssertEqual(b, nil, err) utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, "<h1>Hello, World! Test</h1>", string(c.Response().Body())) utils.AssertEqual(b, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
} }
@ -3083,8 +3103,8 @@ func Benchmark_Ctx_RedirectToRoute(b *testing.B) {
"name": "fiber", "name": "fiber",
}) })
} }
utils.AssertEqual(b, nil, err) utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, 302, c.Response().StatusCode()) utils.AssertEqual(b, 302, c.Response().StatusCode())
utils.AssertEqual(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) utils.AssertEqual(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation)))
} }
@ -3108,8 +3128,8 @@ func Benchmark_Ctx_RedirectToRouteWithQueries(b *testing.B) {
"queries": map[string]string{"a": "a", "b": "b"}, "queries": map[string]string{"a": "a", "b": "b"},
}) })
} }
utils.AssertEqual(b, nil, err) utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, 302, c.Response().StatusCode()) utils.AssertEqual(b, 302, c.Response().StatusCode())
// analysis of query parameters with url parsing, since a map pass is always randomly ordered // analysis of query parameters with url parsing, since a map pass is always randomly ordered
location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation))) location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation)))
@ -3138,8 +3158,8 @@ func Benchmark_Ctx_RenderLocals(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
err = c.Render("index.tmpl", Map{}) err = c.Render("index.tmpl", Map{})
} }
utils.AssertEqual(b, nil, err) utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body())) utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
} }
@ -3164,8 +3184,8 @@ func Benchmark_Ctx_RenderBind(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
err = c.Render("index.tmpl", Map{}) err = c.Render("index.tmpl", Map{})
} }
utils.AssertEqual(b, nil, err) utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body())) utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
} }
@ -3191,8 +3211,7 @@ func Test_Ctx_RestartRouting(t *testing.T) {
func Test_Ctx_RestartRoutingWithChangedPath(t *testing.T) { func Test_Ctx_RestartRoutingWithChangedPath(t *testing.T) {
t.Parallel() t.Parallel()
app := New() app := New()
executedOldHandler := false var executedOldHandler, executedNewHandler bool
executedNewHandler := false
app.Get("/old", func(c *Ctx) error { app.Get("/old", func(c *Ctx) error {
c.Path("/new") c.Path("/new")
@ -3242,10 +3261,18 @@ type testTemplateEngine struct {
func (t *testTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error { func (t *testTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
if len(layout) == 0 { if len(layout) == 0 {
return t.templates.ExecuteTemplate(w, name, bind) if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
return fmt.Errorf("failed to execute template without layout: %w", err)
} }
_ = t.templates.ExecuteTemplate(w, name, bind) return nil
return t.templates.ExecuteTemplate(w, layout[0], bind) }
if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
return fmt.Errorf("failed to execute template: %w", err)
}
if err := t.templates.ExecuteTemplate(w, layout[0], bind); err != nil {
return fmt.Errorf("failed to execute template with layout: %w", err)
}
return nil
} }
func (t *testTemplateEngine) Load() error { func (t *testTemplateEngine) Load() error {
@ -3324,7 +3351,6 @@ func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) {
} }
utils.AssertEqual(b, "/user/fiber", location) utils.AssertEqual(b, "/user/fiber", location)
utils.AssertEqual(b, nil, err) utils.AssertEqual(b, nil, err)
} }
// go test -run Test_Ctx_Get_Location_From_Route_name // go test -run Test_Ctx_Get_Location_From_Route_name
@ -3407,11 +3433,11 @@ func Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param(t *testing.
type errorTemplateEngine struct{} type errorTemplateEngine struct{}
func (t errorTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error { func (errorTemplateEngine) Render(_ io.Writer, _ string, _ interface{}, _ ...string) error {
return errors.New("errorTemplateEngine") return errors.New("errorTemplateEngine")
} }
func (t errorTemplateEngine) Load() error { return nil } func (errorTemplateEngine) Load() error { return nil }
// go test -run Test_Ctx_Render_Engine_Error // go test -run Test_Ctx_Render_Engine_Error
func Test_Ctx_Render_Engine_Error(t *testing.T) { func Test_Ctx_Render_Engine_Error(t *testing.T) {
@ -3429,7 +3455,10 @@ func Test_Ctx_Render_Go_Template(t *testing.T) {
t.Parallel() t.Parallel()
file, err := os.CreateTemp(os.TempDir(), "fiber") file, err := os.CreateTemp(os.TempDir(), "fiber")
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer os.Remove(file.Name()) defer func() {
err := os.Remove(file.Name())
utils.AssertEqual(t, nil, err)
}()
_, err = file.Write([]byte("template")) _, err = file.Write([]byte("template"))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -3821,7 +3850,7 @@ func Test_Ctx_QueryParser(t *testing.T) {
} }
rq := new(RequiredQuery) rq := new(RequiredQuery)
c.Request().URI().SetQueryString("") c.Request().URI().SetQueryString("")
utils.AssertEqual(t, "name is empty", c.QueryParser(rq).Error()) utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(rq).Error())
type ArrayQuery struct { type ArrayQuery struct {
Data []string Data []string
@ -3837,7 +3866,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
t.Parallel() t.Parallel()
type NonRFCTime time.Time type NonRFCTime time.Time
NonRFCConverter := func(value string) reflect.Value { nonRFCConverter := func(value string) reflect.Value {
if v, err := time.Parse("2006-01-02", value); err == nil { if v, err := time.Parse("2006-01-02", value); err == nil {
return reflect.ValueOf(v) return reflect.ValueOf(v)
} }
@ -3846,7 +3875,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
nonRFCTime := ParserType{ nonRFCTime := ParserType{
Customtype: NonRFCTime{}, Customtype: NonRFCTime{},
Converter: NonRFCConverter, Converter: nonRFCConverter,
} }
SetParserDecoder(ParserConfig{ SetParserDecoder(ParserConfig{
@ -3872,7 +3901,6 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
c.Request().URI().SetQueryString("date=2021-04-10&title=CustomDateTest&Body=October") c.Request().URI().SetQueryString("date=2021-04-10&title=CustomDateTest&Body=October")
utils.AssertEqual(t, nil, c.QueryParser(q)) utils.AssertEqual(t, nil, c.QueryParser(q))
fmt.Println(q.Date, "q.Date")
utils.AssertEqual(t, "CustomDateTest", q.Title) utils.AssertEqual(t, "CustomDateTest", q.Title)
date := fmt.Sprintf("%v", q.Date) date := fmt.Sprintf("%v", q.Date)
utils.AssertEqual(t, "{0 63753609600 <nil>}", date) utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
@ -3907,7 +3935,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("namex=tom&nested.age=10") c.Request().URI().SetQueryString("namex=tom&nested.age=10")
q = new(Query1) q = new(Query1)
utils.AssertEqual(t, "name is empty", c.QueryParser(q).Error()) utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(q).Error())
c.Request().URI().SetQueryString("name=tom&nested.agex=10") c.Request().URI().SetQueryString("name=tom&nested.agex=10")
q = new(Query1) q = new(Query1)
@ -3915,7 +3943,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("name=tom&test.age=10") c.Request().URI().SetQueryString("name=tom&test.age=10")
q = new(Query1) q = new(Query1)
utils.AssertEqual(t, "nested is empty", c.QueryParser(q).Error()) utils.AssertEqual(t, "failed to decode: nested is empty", c.QueryParser(q).Error())
type Query2 struct { type Query2 struct {
Name string `query:"name"` Name string `query:"name"`
@ -3933,11 +3961,11 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("nested.agex=10") c.Request().URI().SetQueryString("nested.agex=10")
q2 = new(Query2) q2 = new(Query2)
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error()) utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
c.Request().URI().SetQueryString("nested.agex=10") c.Request().URI().SetQueryString("nested.agex=10")
q2 = new(Query2) q2 = new(Query2)
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error()) utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
type Node struct { type Node struct {
Value int `query:"val,required"` Value int `query:"val,required"`
@ -3951,7 +3979,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("next.val=2") c.Request().URI().SetQueryString("next.val=2")
n = new(Node) n = new(Node)
utils.AssertEqual(t, "val is empty", c.QueryParser(n).Error()) utils.AssertEqual(t, "failed to decode: val is empty", c.QueryParser(n).Error())
c.Request().URI().SetQueryString("val=3&next.value=2") c.Request().URI().SetQueryString("val=3&next.value=2")
n = new(Node) n = new(Node)
@ -4057,7 +4085,7 @@ func Test_Ctx_ReqHeaderParser(t *testing.T) {
} }
rh := new(RequiredHeader) rh := new(RequiredHeader)
c.Request().Header.Del("name") c.Request().Header.Del("name")
utils.AssertEqual(t, "name is empty", c.ReqHeaderParser(rh).Error()) utils.AssertEqual(t, "failed to decode: name is empty", c.ReqHeaderParser(rh).Error())
} }
// go test -run Test_Ctx_ReqHeaderParser_WithSetParserDecoder -v // go test -run Test_Ctx_ReqHeaderParser_WithSetParserDecoder -v
@ -4065,7 +4093,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
t.Parallel() t.Parallel()
type NonRFCTime time.Time type NonRFCTime time.Time
NonRFCConverter := func(value string) reflect.Value { nonRFCConverter := func(value string) reflect.Value {
if v, err := time.Parse("2006-01-02", value); err == nil { if v, err := time.Parse("2006-01-02", value); err == nil {
return reflect.ValueOf(v) return reflect.ValueOf(v)
} }
@ -4074,7 +4102,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
nonRFCTime := ParserType{ nonRFCTime := ParserType{
Customtype: NonRFCTime{}, Customtype: NonRFCTime{},
Converter: NonRFCConverter, Converter: nonRFCConverter,
} }
SetParserDecoder(ParserConfig{ SetParserDecoder(ParserConfig{
@ -4103,7 +4131,6 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
c.Request().Header.Add("Body", "October") c.Request().Header.Add("Body", "October")
utils.AssertEqual(t, nil, c.ReqHeaderParser(r)) utils.AssertEqual(t, nil, c.ReqHeaderParser(r))
fmt.Println(r.Date, "q.Date")
utils.AssertEqual(t, "CustomDateTest", r.Title) utils.AssertEqual(t, "CustomDateTest", r.Title)
date := fmt.Sprintf("%v", r.Date) date := fmt.Sprintf("%v", r.Date)
utils.AssertEqual(t, "{0 63753609600 <nil>}", date) utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
@ -4140,7 +4167,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Name") c.Request().Header.Del("Name")
q = new(Header1) q = new(Header1)
utils.AssertEqual(t, "Name is empty", c.ReqHeaderParser(q).Error()) utils.AssertEqual(t, "failed to decode: Name is empty", c.ReqHeaderParser(q).Error())
c.Request().Header.Add("Name", "tom") c.Request().Header.Add("Name", "tom")
c.Request().Header.Del("Nested.Age") c.Request().Header.Del("Nested.Age")
@ -4150,7 +4177,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Nested.Agex") c.Request().Header.Del("Nested.Agex")
q = new(Header1) q = new(Header1)
utils.AssertEqual(t, "Nested is empty", c.ReqHeaderParser(q).Error()) utils.AssertEqual(t, "failed to decode: Nested is empty", c.ReqHeaderParser(q).Error())
c.Request().Header.Del("Nested.Agex") c.Request().Header.Del("Nested.Agex")
c.Request().Header.Del("Name") c.Request().Header.Del("Name")
@ -4176,7 +4203,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Nested.Age") c.Request().Header.Del("Nested.Age")
c.Request().Header.Add("Nested.Agex", "10") c.Request().Header.Add("Nested.Agex", "10")
h2 = new(Header2) h2 = new(Header2)
utils.AssertEqual(t, "Nested.age is empty", c.ReqHeaderParser(h2).Error()) utils.AssertEqual(t, "failed to decode: Nested.age is empty", c.ReqHeaderParser(h2).Error())
type Node struct { type Node struct {
Value int `reqHeader:"Val,required"` Value int `reqHeader:"Val,required"`
@ -4191,7 +4218,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Val") c.Request().Header.Del("Val")
n = new(Node) n = new(Node)
utils.AssertEqual(t, "Val is empty", c.ReqHeaderParser(n).Error()) utils.AssertEqual(t, "failed to decode: Val is empty", c.ReqHeaderParser(n).Error())
c.Request().Header.Add("Val", "3") c.Request().Header.Add("Val", "3")
c.Request().Header.Del("Next.Val") c.Request().Header.Del("Next.Val")
@ -4628,8 +4655,9 @@ func Test_Ctx_RepeatParserWithSameStruct(t *testing.T) {
var gzipJSON bytes.Buffer var gzipJSON bytes.Buffer
w := gzip.NewWriter(&gzipJSON) w := gzip.NewWriter(&gzipJSON)
_, _ = w.Write([]byte(`{"body_param":"body_param"}`)) _, _ = w.Write([]byte(`{"body_param":"body_param"}`)) //nolint:errcheck // This will never fail
_ = w.Close() err := w.Close()
utils.AssertEqual(t, nil, err)
c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().Header.SetContentType(MIMEApplicationJSON)
c.Request().Header.Set(HeaderContentEncoding, "gzip") c.Request().Header.Set(HeaderContentEncoding, "gzip")
c.Request().SetBody(gzipJSON.Bytes()) c.Request().SetBody(gzipJSON.Bytes())

View File

@ -1,11 +1,10 @@
package fiber package fiber
import ( import (
"encoding/json"
"errors" "errors"
"testing" "testing"
jerrors "encoding/json"
"github.com/gofiber/fiber/v2/internal/schema" "github.com/gofiber/fiber/v2/internal/schema"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
) )
@ -36,42 +35,42 @@ func TestMultiError(t *testing.T) {
func TestInvalidUnmarshalError(t *testing.T) { func TestInvalidUnmarshalError(t *testing.T) {
t.Parallel() t.Parallel()
var e *jerrors.InvalidUnmarshalError var e *json.InvalidUnmarshalError
ok := errors.As(&InvalidUnmarshalError{}, &e) ok := errors.As(&InvalidUnmarshalError{}, &e)
utils.AssertEqual(t, true, ok) utils.AssertEqual(t, true, ok)
} }
func TestMarshalerError(t *testing.T) { func TestMarshalerError(t *testing.T) {
t.Parallel() t.Parallel()
var e *jerrors.MarshalerError var e *json.MarshalerError
ok := errors.As(&MarshalerError{}, &e) ok := errors.As(&MarshalerError{}, &e)
utils.AssertEqual(t, true, ok) utils.AssertEqual(t, true, ok)
} }
func TestSyntaxError(t *testing.T) { func TestSyntaxError(t *testing.T) {
t.Parallel() t.Parallel()
var e *jerrors.SyntaxError var e *json.SyntaxError
ok := errors.As(&SyntaxError{}, &e) ok := errors.As(&SyntaxError{}, &e)
utils.AssertEqual(t, true, ok) utils.AssertEqual(t, true, ok)
} }
func TestUnmarshalTypeError(t *testing.T) { func TestUnmarshalTypeError(t *testing.T) {
t.Parallel() t.Parallel()
var e *jerrors.UnmarshalTypeError var e *json.UnmarshalTypeError
ok := errors.As(&UnmarshalTypeError{}, &e) ok := errors.As(&UnmarshalTypeError{}, &e)
utils.AssertEqual(t, true, ok) utils.AssertEqual(t, true, ok)
} }
func TestUnsupportedTypeError(t *testing.T) { func TestUnsupportedTypeError(t *testing.T) {
t.Parallel() t.Parallel()
var e *jerrors.UnsupportedTypeError var e *json.UnsupportedTypeError
ok := errors.As(&UnsupportedTypeError{}, &e) ok := errors.As(&UnsupportedTypeError{}, &e)
utils.AssertEqual(t, true, ok) utils.AssertEqual(t, true, ok)
} }
func TestUnsupportedValeError(t *testing.T) { func TestUnsupportedValeError(t *testing.T) {
t.Parallel() t.Parallel()
var e *jerrors.UnsupportedValueError var e *json.UnsupportedValueError
ok := errors.As(&UnsupportedValueError{}, &e) ok := errors.As(&UnsupportedValueError{}, &e)
utils.AssertEqual(t, true, ok) utils.AssertEqual(t, true, ok)
} }

View File

@ -168,7 +168,6 @@ func (grp *Group) Group(prefix string, handlers ...Handler) Router {
} }
return newGrp return newGrp
} }
// Route is used to define routes with a common prefix inside the common function. // Route is used to define routes with a common prefix inside the common function.

View File

@ -20,13 +20,13 @@ import (
"unsafe" "unsafe"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool" "github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
/* #nosec */ // getTLSConfig returns a net listener's tls config
// getTlsConfig returns a net listener's tls config func getTLSConfig(ln net.Listener) *tls.Config {
func getTlsConfig(ln net.Listener) *tls.Config {
// Get listener type // Get listener type
pointer := reflect.ValueOf(ln) pointer := reflect.ValueOf(ln)
@ -37,12 +37,16 @@ func getTlsConfig(ln net.Listener) *tls.Config {
// Get private field from value // Get private field from value
if field := val.FieldByName("config"); field.Type() != nil { if field := val.FieldByName("config"); field.Type() != nil {
// Copy value from pointer field (unsafe) // Copy value from pointer field (unsafe)
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) // #nosec G103 newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) //nolint:gosec // Probably the only way to extract the *tls.Config from a net.Listener. TODO: Verify there really is no easier way without using unsafe.
if newval.Type() != nil { if newval.Type() != nil {
// Get element from pointer // Get element from pointer
if elem := newval.Elem(); elem.Type() != nil { if elem := newval.Elem(); elem.Type() != nil {
// Cast value to *tls.Config // Cast value to *tls.Config
return elem.Interface().(*tls.Config) c, ok := elem.Interface().(*tls.Config)
if !ok {
panic(fmt.Errorf("failed to type-assert to *tls.Config"))
}
return c
} }
} }
} }
@ -53,19 +57,21 @@ func getTlsConfig(ln net.Listener) *tls.Config {
} }
// readContent opens a named file and read content from it // readContent opens a named file and read content from it
func readContent(rf io.ReaderFrom, name string) (n int64, err error) { func readContent(rf io.ReaderFrom, name string) (int64, error) {
// Read file // Read file
f, err := os.Open(filepath.Clean(name)) f, err := os.Open(filepath.Clean(name))
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("failed to open: %w", err)
} }
// #nosec G307
defer func() { defer func() {
if err = f.Close(); err != nil { if err = f.Close(); err != nil {
log.Printf("Error closing file: %s\n", err) log.Printf("Error closing file: %s\n", err)
} }
}() }()
return rf.ReadFrom(f) if n, err := rf.ReadFrom(f); err != nil {
return n, fmt.Errorf("failed to read: %w", err)
}
return 0, nil
} }
// quoteString escape special characters in a given string // quoteString escape special characters in a given string
@ -78,7 +84,8 @@ func (app *App) quoteString(raw string) string {
} }
// Scan stack if other methods match the request // Scan stack if other methods match the request
func (app *App) methodExist(ctx *Ctx) (exist bool) { func (app *App) methodExist(ctx *Ctx) bool {
var exists bool
methods := app.config.RequestMethods methods := app.config.RequestMethods
for i := 0; i < len(methods); i++ { for i := 0; i < len(methods); i++ {
// Skip original method // Skip original method
@ -108,7 +115,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
// No match, next route // No match, next route
if match { if match {
// We matched // We matched
exist = true exists = true
// Add method to Allow header // Add method to Allow header
ctx.Append(HeaderAllow, methods[i]) ctx.Append(HeaderAllow, methods[i])
// Break stack loop // Break stack loop
@ -116,7 +123,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
} }
} }
} }
return return exists
} }
// uniqueRouteStack drop all not unique routes from the slice // uniqueRouteStack drop all not unique routes from the slice
@ -146,7 +153,7 @@ func defaultString(value string, defaultValue []string) string {
const normalizedHeaderETag = "Etag" const normalizedHeaderETag = "Etag"
// Generate and set ETag header to response // Generate and set ETag header to response
func setETag(c *Ctx, weak bool) { func setETag(c *Ctx, weak bool) { //nolint: revive // Accepting a bool param is fine here
// Don't generate ETags for invalid responses // Don't generate ETags for invalid responses
if c.fasthttp.Response.StatusCode() != StatusOK { if c.fasthttp.Response.StatusCode() != StatusOK {
return return
@ -160,7 +167,8 @@ func setETag(c *Ctx, weak bool) {
clientEtag := c.Get(HeaderIfNoneMatch) clientEtag := c.Get(HeaderIfNoneMatch)
// Generate ETag for response // Generate ETag for response
crc32q := crc32.MakeTable(0xD5828281) const pol = 0xD5828281
crc32q := crc32.MakeTable(pol)
etag := fmt.Sprintf("\"%d-%v\"", len(body), crc32.Checksum(body, crc32q)) etag := fmt.Sprintf("\"%d-%v\"", len(body), crc32.Checksum(body, crc32q))
// Enable weak tag // Enable weak tag
@ -173,7 +181,9 @@ func setETag(c *Ctx, weak bool) {
// Check if server's ETag is weak // Check if server's ETag is weak
if clientEtag[2:] == etag || clientEtag[2:] == etag[2:] { if clientEtag[2:] == etag || clientEtag[2:] == etag[2:] {
// W/1 == 1 || W/1 == W/1 // W/1 == 1 || W/1 == W/1
_ = c.SendStatus(StatusNotModified) if err := c.SendStatus(StatusNotModified); err != nil {
log.Printf("setETag: failed to SendStatus: %v\n", err)
}
c.fasthttp.ResetBody() c.fasthttp.ResetBody()
return return
} }
@ -183,7 +193,9 @@ func setETag(c *Ctx, weak bool) {
} }
if strings.Contains(clientEtag, etag) { if strings.Contains(clientEtag, etag) {
// 1 == 1 // 1 == 1
_ = c.SendStatus(StatusNotModified) if err := c.SendStatus(StatusNotModified); err != nil {
log.Printf("setETag: failed to SendStatus: %v\n", err)
}
c.fasthttp.ResetBody() c.fasthttp.ResetBody()
return return
} }
@ -239,7 +251,7 @@ func getOffer(header string, offers ...string) string {
return "" return ""
} }
func matchEtag(s string, etag string) bool { func matchEtag(s, etag string) bool {
if s == etag || s == "W/"+etag || "W/"+s == etag { if s == etag || s == "W/"+etag || "W/"+s == etag {
return true return true
} }
@ -254,12 +266,12 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
// https://github.com/jshttp/fresh/blob/10e0471669dbbfbfd8de65bc6efac2ddd0bfa057/index.js#L110 // https://github.com/jshttp/fresh/blob/10e0471669dbbfbfd8de65bc6efac2ddd0bfa057/index.js#L110
for i := range noneMatchBytes { for i := range noneMatchBytes {
switch noneMatchBytes[i] { switch noneMatchBytes[i] {
case 0x20: case 0x20: //nolint:gomnd // This is a space (" ")
if start == end { if start == end {
start = i + 1 start = i + 1
end = i + 1 end = i + 1
} }
case 0x2c: case 0x2c: //nolint:gomnd // This is a comma (",")
if matchEtag(app.getString(noneMatchBytes[start:end]), etag) { if matchEtag(app.getString(noneMatchBytes[start:end]), etag) {
return false return false
} }
@ -273,7 +285,7 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
return !matchEtag(app.getString(noneMatchBytes[start:end]), etag) return !matchEtag(app.getString(noneMatchBytes[start:end]), etag)
} }
func parseAddr(raw string) (host, port string) { func parseAddr(raw string) (string, string) { //nolint:revive // Returns (host, port)
if i := strings.LastIndex(raw, ":"); i != -1 { if i := strings.LastIndex(raw, ":"); i != -1 {
return raw[:i], raw[i+1:] return raw[:i], raw[i+1:]
} }
@ -313,21 +325,21 @@ type testConn struct {
w bytes.Buffer w bytes.Buffer
} }
func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) } func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) } //nolint:wrapcheck // This must not be wrapped
func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) } func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) } //nolint:wrapcheck // This must not be wrapped
func (c *testConn) Close() error { return nil } func (*testConn) Close() error { return nil }
func (c *testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} } func (*testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
func (c *testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} } func (*testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
func (c *testConn) SetDeadline(_ time.Time) error { return nil } func (*testConn) SetDeadline(_ time.Time) error { return nil }
func (c *testConn) SetReadDeadline(_ time.Time) error { return nil } func (*testConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *testConn) SetWriteDeadline(_ time.Time) error { return nil } func (*testConn) SetWriteDeadline(_ time.Time) error { return nil }
var getStringImmutable = func(b []byte) string { func getStringImmutable(b []byte) string {
return string(b) return string(b)
} }
var getBytesImmutable = func(s string) (b []byte) { func getBytesImmutable(s string) []byte {
return []byte(s) return []byte(s)
} }
@ -335,6 +347,7 @@ var getBytesImmutable = func(s string) (b []byte) {
func (app *App) methodInt(s string) int { func (app *App) methodInt(s string) int {
// For better performance // For better performance
if len(app.configured.RequestMethods) == 0 { if len(app.configured.RequestMethods) == 0 {
//nolint:gomnd // TODO: Use iota instead
switch s { switch s {
case MethodGet: case MethodGet:
return 0 return 0
@ -391,8 +404,7 @@ func IsMethodIdempotent(m string) bool {
} }
switch m { switch m {
case MethodPut, case MethodPut, MethodDelete:
MethodDelete:
return true return true
default: default:
return false return false
@ -714,7 +726,7 @@ const (
ConstraintBool = "bool" ConstraintBool = "bool"
ConstraintFloat = "float" ConstraintFloat = "float"
ConstraintAlpha = "alpha" ConstraintAlpha = "alpha"
ConstraintGuid = "guid" ConstraintGuid = "guid" //nolint:revive,stylecheck // TODO: Rename to "ConstraintGUID" in v3
ConstraintMinLen = "minLen" ConstraintMinLen = "minLen"
ConstraintMaxLen = "maxLen" ConstraintMaxLen = "maxLen"
ConstraintLen = "len" ConstraintLen = "len"

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )

View File

@ -1,14 +1,20 @@
package fiber package fiber
import (
"log"
)
// OnRouteHandler Handlers define a function to create hooks for Fiber. // OnRouteHandler Handlers define a function to create hooks for Fiber.
type OnRouteHandler = func(Route) error type (
type OnNameHandler = OnRouteHandler OnRouteHandler = func(Route) error
type OnGroupHandler = func(Group) error OnNameHandler = OnRouteHandler
type OnGroupNameHandler = OnGroupHandler OnGroupHandler = func(Group) error
type OnListenHandler = func() error OnGroupNameHandler = OnGroupHandler
type OnShutdownHandler = OnListenHandler OnListenHandler = func() error
type OnForkHandler = func(int) error OnShutdownHandler = OnListenHandler
type OnMountHandler = func(*App) error OnForkHandler = func(int) error
OnMountHandler = func(*App) error
)
// Hooks is a struct to use it with App. // Hooks is a struct to use it with App.
type Hooks struct { type Hooks struct {
@ -180,13 +186,17 @@ func (h *Hooks) executeOnListenHooks() error {
func (h *Hooks) executeOnShutdownHooks() { func (h *Hooks) executeOnShutdownHooks() {
for _, v := range h.onShutdown { for _, v := range h.onShutdown {
_ = v() if err := v(); err != nil {
log.Printf("failed to call shutdown hook: %v\n", err)
}
} }
} }
func (h *Hooks) executeOnForkHooks(pid int) { func (h *Hooks) executeOnForkHooks(pid int) {
for _, v := range h.onFork { for _, v := range h.onFork {
_ = v(pid) if err := v(pid); err != nil {
log.Printf("failed to call fork hook: %v\n", err)
}
} }
} }

View File

@ -7,10 +7,11 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool" "github.com/valyala/bytebufferpool"
) )
var testSimpleHandler = func(c *Ctx) error { func testSimpleHandler(c *Ctx) error {
return c.SendString("simple") return c.SendString("simple")
} }

View File

@ -10,7 +10,7 @@ var defaultPool = sync.Pool{
// AcquireDict acquire new dict. // AcquireDict acquire new dict.
func AcquireDict() *Dict { func AcquireDict() *Dict {
return defaultPool.Get().(*Dict) // nolint:forcetypeassert return defaultPool.Get().(*Dict)
} }
// ReleaseDict release dict. // ReleaseDict release dict.

View File

@ -366,7 +366,7 @@ func HostDev(combineWith ...string) string {
// getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running // getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running
// sysctl commands (see DoSysctrl). // sysctl commands (see DoSysctrl).
func getSysctrlEnv(env []string) []string { func getSysctrlEnv(env []string) []string {
foundLC := false var foundLC bool
for i, line := range env { for i, line := range env {
if strings.HasPrefix(line, "LC_ALL") { if strings.HasPrefix(line, "LC_ALL") {
env[i] = "LC_ALL=C" env[i] = "LC_ALL=C"

View File

@ -6,8 +6,7 @@ import (
"github.com/gofiber/fiber/v2/internal/gopsutil/common" "github.com/gofiber/fiber/v2/internal/gopsutil/common"
) )
//lint:ignore U1000 we need this elsewhere var invoke common.Invoker = common.Invoke{} //nolint:unused // We use this only for some OS'es
var invoke common.Invoker = common.Invoke{} //nolint:all
// Memory usage statistics. Total, Available and Used contain numbers of bytes // Memory usage statistics. Total, Available and Used contain numbers of bytes
// for human consumption. // for human consumption.

View File

@ -86,7 +86,6 @@ func SwapMemory() (*SwapMemoryStat, error) {
} }
// Constants from vm/vm_param.h // Constants from vm/vm_param.h
// nolint: golint
const ( const (
XSWDEV_VERSION11 = 1 XSWDEV_VERSION11 = 1
XSWDEV_VERSION = 2 XSWDEV_VERSION = 2

View File

@ -57,10 +57,12 @@ func fillFromMeminfoWithContext(ctx context.Context) (*VirtualMemoryStat, *Virtu
lines, _ := common.ReadLines(filename) lines, _ := common.ReadLines(filename)
// flag if MemAvailable is in /proc/meminfo (kernel 3.14+) // flag if MemAvailable is in /proc/meminfo (kernel 3.14+)
memavail := false var (
activeFile := false // "Active(file)" not available: 2.6.28 / Dec 2008 memavail bool
inactiveFile := false // "Inactive(file)" not available: 2.6.28 / Dec 2008 activeFile bool // "Active(file)" not available: 2.6.28 / Dec 2008
sReclaimable := false // "SReclaimable:" not available: 2.6.19 / Nov 2006 inactiveFile bool // "Inactive(file)" not available: 2.6.28 / Dec 2008
sReclaimable bool // "SReclaimable:" not available: 2.6.19 / Nov 2006
)
ret := &VirtualMemoryStat{} ret := &VirtualMemoryStat{}
retEx := &VirtualMemoryExStat{} retEx := &VirtualMemoryExStat{}

View File

@ -168,7 +168,7 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) {
if err != nil { if err != nil {
return return
} }
comma := false var comma bool
for i := uint32(0); i < sz; i++ { for i := uint32(0); i < sz; i++ {
if comma { if comma {
err = dst.WriteByte(',') err = dst.WriteByte(',')

View File

@ -163,7 +163,7 @@ func (c *cache) createField(field reflect.StructField, parentAlias string) *fiel
} }
// Check if the type is supported and don't cache it if not. // Check if the type is supported and don't cache it if not.
// First let's get the basic type. // First let's get the basic type.
isSlice, isStruct := false, false var isSlice, isStruct bool
ft := field.Type ft := field.Type
m := isTextUnmarshaler(reflect.Zero(ft)) m := isTextUnmarshaler(reflect.Zero(ft))
if ft.Kind() == reflect.Ptr { if ft.Kind() == reflect.Ptr {

View File

@ -149,13 +149,13 @@ func Benchmark_Storage_Memory(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
for _, key := range keys { for _, key := range keys {
d.Set(key, value, ttl) _ = d.Set(key, value, ttl)
} }
for _, key := range keys { for _, key := range keys {
_, _ = d.Get(key) _, _ = d.Get(key)
} }
for _, key := range keys { for _, key := range keys {
d.Delete(key) _ = d.Delete(key)
} }
} }
}) })

View File

@ -156,7 +156,6 @@ func (e *Engine) Load() error {
name = strings.TrimSuffix(name, e.extension) name = strings.TrimSuffix(name, e.extension)
// name = strings.Replace(name, e.extension, "", -1) // name = strings.Replace(name, e.extension, "", -1)
// Read the file // Read the file
// #gosec G304
buf, err := utils.ReadFile(path, e.fileSystem) buf, err := utils.ReadFile(path, e.fileSystem)
if err != nil { if err != nil {
return err return err

View File

@ -21,7 +21,6 @@ func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
return walk(fs, root, info, walkFn) return walk(fs, root, info, walkFn)
} }
// #nosec G304
// ReadFile returns the raw content of a file // ReadFile returns the raw content of a file
func ReadFile(path string, fs http.FileSystem) ([]byte, error) { func ReadFile(path string, fs http.FileSystem) ([]byte, error) {
if fs != nil { if fs != nil {

View File

@ -9,6 +9,7 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"log"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -31,7 +32,7 @@ func (app *App) Listener(ln net.Listener) error {
// Print startup message // Print startup message
if !app.config.DisableStartupMessage { if !app.config.DisableStartupMessage {
app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "") app.startupMessage(ln.Addr().String(), getTLSConfig(ln) != nil, "")
} }
// Print routes // Print routes
@ -41,7 +42,7 @@ func (app *App) Listener(ln net.Listener) error {
// Prefork is not supported for custom listeners // Prefork is not supported for custom listeners
if app.config.Prefork { if app.config.Prefork {
fmt.Println("[Warning] Prefork isn't supported for custom listeners.") log.Printf("[Warning] Prefork isn't supported for custom listeners.\n")
} }
// Start listening // Start listening
@ -61,7 +62,7 @@ func (app *App) Listen(addr string) error {
// Setup listener // Setup listener
ln, err := net.Listen(app.config.Network, addr) ln, err := net.Listen(app.config.Network, addr)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to listen: %w", err)
} }
// prepare the server for the start // prepare the server for the start
@ -94,7 +95,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
// Set TLS config with handler // Set TLS config with handler
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
} }
tlsHandler := &TLSHandler{} tlsHandler := &TLSHandler{}
@ -115,7 +116,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
ln, err := net.Listen(app.config.Network, addr) ln, err := net.Listen(app.config.Network, addr)
ln = tls.NewListener(ln, config) ln = tls.NewListener(ln, config)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to listen: %w", err)
} }
// prepare the server for the start // prepare the server for the start
@ -150,12 +151,12 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
} }
clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile)) clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile))
if err != nil { if err != nil {
return err return fmt.Errorf("failed to read file: %w", err)
} }
clientCertPool := x509.NewCertPool() clientCertPool := x509.NewCertPool()
clientCertPool.AppendCertsFromPEM(clientCACert) clientCertPool.AppendCertsFromPEM(clientCACert)
@ -179,7 +180,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
// Setup listener // Setup listener
ln, err := tls.Listen(app.config.Network, addr, config) ln, err := tls.Listen(app.config.Network, addr, config)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to listen: %w", err)
} }
// prepare the server for the start // prepare the server for the start
@ -203,7 +204,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
} }
// startupMessage prepares the startup message with the handler number, port, address and other information // startupMessage prepares the startup message with the handler number, port, address and other information
func (app *App) startupMessage(addr string, tls bool, pids string) { func (app *App) startupMessage(addr string, tls bool, pids string) { //nolint: revive // Accepting a bool param is fine here
// ignore child processes // ignore child processes
if IsChild() { if IsChild() {
return return
@ -227,7 +228,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
} }
center := func(s string, width int) string { center := func(s string, width int) string {
pad := strconv.Itoa((width - len(s)) / 2) const padDiv = 2
pad := strconv.Itoa((width - len(s)) / padDiv)
str := fmt.Sprintf("%"+pad+"s", " ") str := fmt.Sprintf("%"+pad+"s", " ")
str += s str += s
str += fmt.Sprintf("%"+pad+"s", " ") str += fmt.Sprintf("%"+pad+"s", " ")
@ -238,7 +240,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
} }
centerValue := func(s string, width int) string { centerValue := func(s string, width int) string {
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / 2) const padDiv = 2
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / padDiv)
str := fmt.Sprintf("%"+pad+"s", " ") str := fmt.Sprintf("%"+pad+"s", " ")
str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black) str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black)
str += fmt.Sprintf("%"+pad+"s", " ") str += fmt.Sprintf("%"+pad+"s", " ")
@ -249,13 +252,13 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
return str return str
} }
pad := func(s string, width int) (str string) { pad := func(s string, width int) string {
toAdd := width - len(s) toAdd := width - len(s)
str += s str := s
for i := 0; i < toAdd; i++ { for i := 0; i < toAdd; i++ {
str += " " str += " "
} }
return return str
} }
host, port := parseAddr(addr) host, port := parseAddr(addr)
@ -267,9 +270,9 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
} }
} }
scheme := "http" scheme := schemeHTTP
if tls { if tls {
scheme = "https" scheme = schemeHTTPS
} }
isPrefork := "Disabled" isPrefork := "Disabled"
@ -282,19 +285,18 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
procs = "1" procs = "1"
} }
const lineLen = 49
mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n" mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n"
if app.config.AppName != "" { if app.config.AppName != "" {
mainLogo += " │ " + centerValue(app.config.AppName, 49) + " │\n" mainLogo += " │ " + centerValue(app.config.AppName, lineLen) + " │\n"
} }
mainLogo += " │ " + centerValue("Fiber v"+Version, 49) + " │\n" mainLogo += " │ " + centerValue("Fiber v"+Version, lineLen) + " │\n"
if host == "0.0.0.0" { if host == "0.0.0.0" {
mainLogo += mainLogo += " │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), lineLen) + " │\n" +
" │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), 49) + " │\n" + " │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), lineLen) + " │\n"
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), 49) + " │\n"
} else { } else {
mainLogo += mainLogo += " │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), lineLen) + " │\n"
" │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), 49) + " │\n"
} }
mainLogo += fmt.Sprintf( mainLogo += fmt.Sprintf(
@ -303,8 +305,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
" │ Prefork .%s PID ....%s │\n"+ " │ Prefork .%s PID ....%s │\n"+
" └───────────────────────────────────────────────────┘"+ " └───────────────────────────────────────────────────┘"+
colors.Reset, colors.Reset,
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12), value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12), //nolint:gomnd // Using random padding lengths is fine here
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14), value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14), //nolint:gomnd // Using random padding lengths is fine here
) )
var childPidsLogo string var childPidsLogo string
@ -329,19 +331,21 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
thisLine := "Child PIDs ... " thisLine := "Child PIDs ... "
var itemsOnThisLine []string var itemsOnThisLine []string
const maxLineLen = 49
addLine := func() { addLine := func() {
lines = append(lines, lines = append(lines,
fmt.Sprintf( fmt.Sprintf(
newLine, newLine,
colors.Black, colors.Black,
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), 49-len(thisLine)), thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), maxLineLen-len(thisLine)),
colors.Black, colors.Black,
), ),
) )
} }
for _, pid := range pidSlice { for _, pid := range pidSlice {
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > 49 { if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > maxLineLen {
addLine() addLine()
thisLine = "" thisLine = ""
itemsOnThisLine = []string{pid} itemsOnThisLine = []string{pid}
@ -415,7 +419,7 @@ func (app *App) printRoutesMessage() {
var routes []RouteMessage var routes []RouteMessage
for _, routeStack := range app.stack { for _, routeStack := range app.stack {
for _, route := range routeStack { for _, route := range routeStack {
var newRoute = RouteMessage{} var newRoute RouteMessage
newRoute.name = route.Name newRoute.name = route.Name
newRoute.method = route.Method newRoute.method = route.Method
newRoute.path = route.Path newRoute.path = route.Path
@ -443,5 +447,5 @@ func (app *App) printRoutesMessage() {
_, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers) _, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers)
} }
_ = w.Flush() _ = w.Flush() //nolint:errcheck // It is fine to ignore the error here
} }

View File

@ -7,7 +7,6 @@ package fiber
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"fmt"
"io" "io"
"log" "log"
"os" "os"
@ -17,6 +16,7 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp/fasthttputil" "github.com/valyala/fasthttp/fasthttputil"
) )
@ -125,8 +125,10 @@ func Test_App_Listener_TLS_Listener(t *testing.T) {
if err != nil { if err != nil {
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
} }
//nolint:gosec // We're in a test so using old ciphers is fine
config := &tls.Config{Certificates: []tls.Certificate{cer}} config := &tls.Config{Certificates: []tls.Certificate{cer}}
//nolint:gosec // We're in a test so listening on all interfaces is fine
ln, err := tls.Listen(NetworkTCP4, ":0", config) ln, err := tls.Listen(NetworkTCP4, ":0", config)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -182,7 +184,6 @@ func Test_App_Master_Process_Show_Startup_Message(t *testing.T) {
New(Config{Prefork: true}). New(Config{Prefork: true}).
startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10)) startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
}) })
fmt.Println(startupMessage)
utils.AssertEqual(t, true, strings.Contains(startupMessage, "https://127.0.0.1:3000")) utils.AssertEqual(t, true, strings.Contains(startupMessage, "https://127.0.0.1:3000"))
utils.AssertEqual(t, true, strings.Contains(startupMessage, "(bound on host 0.0.0.0 and port 3000)")) utils.AssertEqual(t, true, strings.Contains(startupMessage, "(bound on host 0.0.0.0 and port 3000)"))
utils.AssertEqual(t, true, strings.Contains(startupMessage, "Child PIDs")) utils.AssertEqual(t, true, strings.Contains(startupMessage, "Child PIDs"))
@ -196,7 +197,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) {
startupMessage := captureOutput(func() { startupMessage := captureOutput(func() {
app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10)) app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
}) })
fmt.Println(startupMessage)
utils.AssertEqual(t, "Test App v1.0.1", app.Config().AppName) utils.AssertEqual(t, "Test App v1.0.1", app.Config().AppName)
utils.AssertEqual(t, true, strings.Contains(startupMessage, app.Config().AppName)) utils.AssertEqual(t, true, strings.Contains(startupMessage, app.Config().AppName))
} }
@ -208,7 +208,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.
startupMessage := captureOutput(func() { startupMessage := captureOutput(func() {
app.startupMessage(":3000", false, "") app.startupMessage(":3000", false, "")
}) })
fmt.Println(startupMessage)
utils.AssertEqual(t, true, strings.Contains(startupMessage, "│ Serveur de vérification des données │")) utils.AssertEqual(t, true, strings.Contains(startupMessage, "│ Serveur de vérification des données │"))
} }
@ -219,8 +218,7 @@ func Test_App_print_Route(t *testing.T) {
printRoutesMessage := captureOutput(func() { printRoutesMessage := captureOutput(func() {
app.printRoutesMessage() app.printRoutesMessage()
}) })
fmt.Println(printRoutesMessage) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "routeName")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "routeName"))
@ -240,11 +238,11 @@ func Test_App_print_Route_with_group(t *testing.T) {
app.printRoutesMessage() app.printRoutesMessage()
}) })
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "POST")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodPost))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "PUT")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "PUT"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber/*")) utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber/*"))

View File

@ -1,15 +1,15 @@
package basicauth package basicauth
import ( import (
"encoding/base64"
"fmt" "fmt"
"io" "io"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
b64 "encoding/base64"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -23,7 +23,7 @@ func Test_BasicAuth_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -39,6 +39,7 @@ func Test_Middleware_BasicAuth(t *testing.T) {
}, },
})) }))
//nolint:forcetypeassert,errcheck // TODO: Do not force-type assert
app.Get("/testauth", func(c *fiber.Ctx) error { app.Get("/testauth", func(c *fiber.Ctx) error {
username := c.Locals("username").(string) username := c.Locals("username").(string)
password := c.Locals("password").(string) password := c.Locals("password").(string)
@ -74,9 +75,9 @@ func Test_Middleware_BasicAuth(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
// Base64 encode credentials for http auth header // Base64 encode credentials for http auth header
creds := b64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password))) creds := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
req := httptest.NewRequest("GET", "/testauth", nil) req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil)
req.Header.Add("Authorization", "Basic "+creds) req.Header.Add("Authorization", "Basic "+creds)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -108,7 +109,7 @@ func Benchmark_Middleware_BasicAuth(b *testing.B) {
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/") fctx.Request.SetRequestURI("/")
fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe

View File

@ -53,6 +53,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
Users: map[string]string{}, Users: map[string]string{},

View File

@ -34,6 +34,7 @@ const (
noStore = "no-store" noStore = "no-store"
) )
//nolint:gochecknoglobals // TODO: Do not use a global var here
var ignoreHeaders = map[string]interface{}{ var ignoreHeaders = map[string]interface{}{
"Connection": nil, "Connection": nil,
"Keep-Alive": nil, "Keep-Alive": nil,
@ -43,8 +44,8 @@ var ignoreHeaders = map[string]interface{}{
"Trailers": nil, "Trailers": nil,
"Transfer-Encoding": nil, "Transfer-Encoding": nil,
"Upgrade": nil, "Upgrade": nil,
"Content-Type": nil, // already stored explicitely by the cache manager "Content-Type": nil, // already stored explicitly by the cache manager
"Content-Encoding": nil, // already stored explicitely by the cache manager "Content-Encoding": nil, // already stored explicitly by the cache manager
} }
// New creates a new middleware handler // New creates a new middleware handler
@ -69,7 +70,7 @@ func New(config ...Config) fiber.Handler {
// Create indexed heap for tracking expirations ( see heap.go ) // Create indexed heap for tracking expirations ( see heap.go )
heap := &indexedHeap{} heap := &indexedHeap{}
// count stored bytes (sizes of response bodies) // count stored bytes (sizes of response bodies)
var storedBytes uint = 0 var storedBytes uint
// Update timestamp in the configured interval // Update timestamp in the configured interval
go func() { go func() {
@ -81,10 +82,10 @@ func New(config ...Config) fiber.Handler {
// Delete key from both manager and storage // Delete key from both manager and storage
deleteKey := func(dkey string) { deleteKey := func(dkey string) {
manager.delete(dkey) manager.del(dkey)
// External storage saves body data with different key // External storage saves body data with different key
if cfg.Storage != nil { if cfg.Storage != nil {
manager.delete(dkey + "_body") manager.del(dkey + "_body")
} }
} }
@ -205,7 +206,7 @@ func New(config ...Config) fiber.Handler {
if cfg.StoreResponseHeaders { if cfg.StoreResponseHeaders {
e.headers = make(map[string][]byte) e.headers = make(map[string][]byte)
c.Response().Header.VisitAll( c.Response().Header.VisitAll(
func(key []byte, value []byte) { func(key, value []byte) {
// create real copy // create real copy
keyS := string(key) keyS := string(key)
if _, ok := ignoreHeaders[keyS]; !ok { if _, ok := ignoreHeaders[keyS]; !ok {

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"strconv" "strconv"
@ -18,6 +17,7 @@ import (
"github.com/gofiber/fiber/v2/internal/storage/memory" "github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/middleware/etag" "github.com/gofiber/fiber/v2/middleware/etag"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -35,10 +35,10 @@ func Test_Cache_CacheControl(t *testing.T) {
return c.SendString("Hello, World!") return c.SendString("Hello, World!")
}) })
_, err := app.Test(httptest.NewRequest("GET", "/", nil)) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl)) utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
} }
@ -53,7 +53,7 @@ func Test_Cache_Expired(t *testing.T) {
return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano())) return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano()))
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -61,7 +61,7 @@ func Test_Cache_Expired(t *testing.T) {
// Sleep until the cache is expired // Sleep until the cache is expired
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil)) respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body) bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -71,7 +71,7 @@ func Test_Cache_Expired(t *testing.T) {
} }
// Next response should be also cached // Next response should be also cached
respCachedNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil)) respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body) bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -92,11 +92,11 @@ func Test_Cache(t *testing.T) {
return c.SendString(now) return c.SendString(now)
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest("GET", "/", nil) cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp, err := app.Test(cachedReq) cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -120,31 +120,31 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
}) })
// Request id = 1 // Request id = 1
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req) resp, err := app.Test(req)
defer resp.Body.Close() utils.AssertEqual(t, nil, err)
body, _ := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("1"), body) utils.AssertEqual(t, []byte("1"), body)
// Response cached, entry id = 1 // Response cached, entry id = 1
// Request id = 2 without Cache-Control: no-cache // Request id = 2 without Cache-Control: no-cache
cachedReq := httptest.NewRequest("GET", "/?id=2", nil) cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
cachedResp, err := app.Test(cachedReq) cachedResp, err := app.Test(cachedReq)
defer cachedResp.Body.Close() utils.AssertEqual(t, nil, err)
cachedBody, _ := io.ReadAll(cachedResp.Body) cachedBody, err := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("1"), cachedBody) utils.AssertEqual(t, []byte("1"), cachedBody)
// Response not cached, returns cached response, entry id = 1 // Response not cached, returns cached response, entry id = 1
// Request id = 2 with Cache-Control: no-cache // Request id = 2 with Cache-Control: no-cache
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil) noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache) noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheResp, err := app.Test(noCacheReq) noCacheResp, err := app.Test(noCacheReq)
defer noCacheResp.Body.Close() utils.AssertEqual(t, nil, err)
noCacheBody, _ := io.ReadAll(noCacheResp.Body) noCacheBody, err := io.ReadAll(noCacheResp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), noCacheBody) utils.AssertEqual(t, []byte("2"), noCacheBody)
@ -152,21 +152,21 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
/* Check Test_Cache_WithETagAndNoCacheRequestDirective */ /* Check Test_Cache_WithETagAndNoCacheRequestDirective */
// Request id = 2 with Cache-Control: no-cache again // Request id = 2 with Cache-Control: no-cache again
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil) noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache) noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheResp1, err := app.Test(noCacheReq1) noCacheResp1, err := app.Test(noCacheReq1)
defer noCacheResp1.Body.Close() utils.AssertEqual(t, nil, err)
noCacheBody1, _ := io.ReadAll(noCacheResp1.Body) noCacheBody1, err := io.ReadAll(noCacheResp1.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), noCacheBody1) utils.AssertEqual(t, []byte("2"), noCacheBody1)
// Response cached, returns updated response, entry = 2 // Response cached, returns updated response, entry = 2
// Request id = 1 without Cache-Control: no-cache // Request id = 1 without Cache-Control: no-cache
cachedReq1 := httptest.NewRequest("GET", "/", nil) cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp1, err := app.Test(cachedReq1) cachedResp1, err := app.Test(cachedReq1)
defer cachedResp1.Body.Close() utils.AssertEqual(t, nil, err)
cachedBody1, _ := io.ReadAll(cachedResp1.Body) cachedBody1, err := io.ReadAll(cachedResp1.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache")) utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), cachedBody1) utils.AssertEqual(t, []byte("2"), cachedBody1)
@ -188,7 +188,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
}) })
// Request id = 1 // Request id = 1
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
@ -199,7 +199,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
etagToken := resp.Header.Get("Etag") etagToken := resp.Header.Get("Etag")
// Request id = 2 with ETag but without Cache-Control: no-cache // Request id = 2 with ETag but without Cache-Control: no-cache
cachedReq := httptest.NewRequest("GET", "/?id=2", nil) cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken) cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
cachedResp, err := app.Test(cachedReq) cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -208,7 +208,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
// Response not cached, returns cached response, entry id = 1, status not modified // Response not cached, returns cached response, entry id = 1, status not modified
// Request id = 2 with ETag and Cache-Control: no-cache // Request id = 2 with ETag and Cache-Control: no-cache
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil) noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache) noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken) noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
noCacheResp, err := app.Test(noCacheReq) noCacheResp, err := app.Test(noCacheReq)
@ -221,7 +221,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
etagToken = noCacheResp.Header.Get("Etag") etagToken = noCacheResp.Header.Get("Etag")
// Request id = 2 with ETag and Cache-Control: no-cache again // Request id = 2 with ETag and Cache-Control: no-cache again
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil) noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache) noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken) noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
noCacheResp1, err := app.Test(noCacheReq1) noCacheResp1, err := app.Test(noCacheReq1)
@ -231,7 +231,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
// Response cached, returns updated response, entry id = 2, status not modified // Response cached, returns updated response, entry id = 2, status not modified
// Request id = 1 without ETag and Cache-Control: no-cache // Request id = 1 without ETag and Cache-Control: no-cache
cachedReq1 := httptest.NewRequest("GET", "/", nil) cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp1, err := app.Test(cachedReq1) cachedResp1, err := app.Test(cachedReq1)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache")) utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
@ -251,11 +251,11 @@ func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
}) })
// Request id = 2 // Request id = 2
noStoreReq := httptest.NewRequest("GET", "/?id=2", nil) noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore) noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
noStoreResp, err := app.Test(noStoreReq) noStoreResp, err := app.Test(noStoreReq)
defer noStoreResp.Body.Close() utils.AssertEqual(t, nil, err)
noStoreBody, _ := io.ReadAll(noStoreResp.Body) noStoreBody, err := io.ReadAll(noStoreResp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("2"), noStoreBody) utils.AssertEqual(t, []byte("2"), noStoreBody)
// Response not cached, returns updated response // Response not cached, returns updated response
@ -278,11 +278,11 @@ func Test_Cache_WithSeveralRequests(t *testing.T) {
for runs := 0; runs < 10; runs++ { for runs := 0; runs < 10; runs++ {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
func(id int) { func(id int) {
rsp, err := app.Test(httptest.NewRequest(http.MethodGet, fmt.Sprintf("/%d", id), nil)) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
defer func(Body io.ReadCloser) { defer func(body io.ReadCloser) {
err := Body.Close() err := body.Close()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
}(rsp.Body) }(rsp.Body)
@ -311,11 +311,11 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
return c.SendString(now) return c.SendString(now)
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest("GET", "/", nil) cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp, err := app.Test(cachedReq) cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -342,25 +342,25 @@ func Test_Cache_Get(t *testing.T) {
return c.SendString(c.Query("cache")) return c.SendString(c.Query("cache"))
}) })
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body)) utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "12345", string(body)) utils.AssertEqual(t, "12345", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body)) utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -384,25 +384,25 @@ func Test_Cache_Post(t *testing.T) {
return c.SendString(c.Query("cache")) return c.SendString(c.Query("cache"))
}) })
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body)) utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body)) utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body)) utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body) body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -420,14 +420,14 @@ func Test_Cache_NothingToCache(t *testing.T) {
return c.SendString(time.Now().String()) return c.SendString(time.Now().String())
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil)) respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body) bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -457,22 +457,22 @@ func Test_Cache_CustomNext(t *testing.T) {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String()) return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil)) respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body) bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Equal(body, bodyCached)) utils.AssertEqual(t, true, bytes.Equal(body, bodyCached))
utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "") utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "")
_, err = app.Test(httptest.NewRequest("GET", "/error", nil)) _, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil)) errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "") utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "")
} }
@ -491,7 +491,7 @@ func Test_CustomKey(t *testing.T) {
return c.SendString("hi") return c.SendString("hi")
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
_, err := app.Test(req) _, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called) utils.AssertEqual(t, true, called)
@ -505,7 +505,9 @@ func Test_CustomExpiration(t *testing.T) {
var newCacheTime int var newCacheTime int
app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration { app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration {
called = true called = true
newCacheTime, _ = strconv.Atoi(c.GetRespHeader("Cache-Time", "600")) var err error
newCacheTime, err = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
utils.AssertEqual(t, nil, err)
return time.Second * time.Duration(newCacheTime) return time.Second * time.Duration(newCacheTime)
}})) }}))
@ -515,7 +517,7 @@ func Test_CustomExpiration(t *testing.T) {
return c.SendString(now) return c.SendString(now)
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called) utils.AssertEqual(t, true, called)
utils.AssertEqual(t, 1, newCacheTime) utils.AssertEqual(t, 1, newCacheTime)
@ -523,7 +525,7 @@ func Test_CustomExpiration(t *testing.T) {
// Sleep until the cache is expired // Sleep until the cache is expired
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil)) cachedResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
@ -536,7 +538,7 @@ func Test_CustomExpiration(t *testing.T) {
} }
// Next response should be cached // Next response should be cached
cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil)) cachedRespNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body) cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -559,12 +561,12 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) {
return c.SendString("hi") return c.SendString("hi")
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar")) utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
req = httptest.NewRequest("GET", "/", nil) req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req) resp, err = app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar")) utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
@ -594,19 +596,19 @@ func Test_CacheHeader(t *testing.T) {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String()) return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest("GET", "/", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache"))
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil)) errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache")) utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
} }
@ -622,12 +624,12 @@ func Test_Cache_WithHead(t *testing.T) {
return c.SendString(now) return c.SendString(now)
}) })
req := httptest.NewRequest("HEAD", "/", nil) req := httptest.NewRequest(fiber.MethodHead, "/", nil)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
cachedReq := httptest.NewRequest("HEAD", "/", nil) cachedReq := httptest.NewRequest(fiber.MethodHead, "/", nil)
cachedResp, err := app.Test(cachedReq) cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
@ -649,28 +651,28 @@ func Test_Cache_WithHeadThenGet(t *testing.T) {
return c.SendString(c.Query("cache")) return c.SendString(c.Query("cache"))
}) })
headResp, err := app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil)) headResp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
headBody, err := io.ReadAll(headResp.Body) headBody, err := io.ReadAll(headResp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(headBody)) utils.AssertEqual(t, "", string(headBody))
utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache"))
headResp, err = app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil)) headResp, err = app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
headBody, err = io.ReadAll(headResp.Body) headBody, err = io.ReadAll(headResp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(headBody)) utils.AssertEqual(t, "", string(headBody))
utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache"))
getResp, err := app.Test(httptest.NewRequest("GET", "/?cache=123", nil)) getResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
getBody, err := io.ReadAll(getResp.Body) getBody, err := io.ReadAll(getResp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(getBody)) utils.AssertEqual(t, "123", string(getBody))
utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache")) utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache"))
getResp, err = app.Test(httptest.NewRequest("GET", "/?cache=123", nil)) getResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
getBody, err = io.ReadAll(getResp.Body) getBody, err = io.ReadAll(getResp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -691,7 +693,7 @@ func Test_CustomCacheHeader(t *testing.T) {
return c.SendString("Hello, World!") return c.SendString("Hello, World!")
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status")) utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status"))
} }
@ -702,7 +704,7 @@ func Test_CustomCacheHeader(t *testing.T) {
func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration { func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration {
i := 0 i := 0
return func(c1 *fiber.Ctx, c2 *Config) time.Duration { return func(c1 *fiber.Ctx, c2 *Config) time.Duration {
i += 1 i++
return time.Hour * time.Duration(i) return time.Hour * time.Duration(i)
} }
} }
@ -738,7 +740,7 @@ func Test_Cache_MaxBytesOrder(t *testing.T) {
} }
for idx, tcase := range cases { for idx, tcase := range cases {
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil)) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx)) utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
} }
@ -756,7 +758,8 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
app.Get("/*", func(c *fiber.Ctx) error { app.Get("/*", func(c *fiber.Ctx) error {
path := c.Context().URI().LastPathSegment() path := c.Context().URI().LastPathSegment()
size, _ := strconv.Atoi(string(path)) size, err := strconv.Atoi(string(path))
utils.AssertEqual(t, nil, err)
return c.Send(make([]byte, size)) return c.Send(make([]byte, size))
}) })
@ -772,7 +775,7 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
} }
for idx, tcase := range cases { for idx, tcase := range cases {
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil)) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx)) utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
} }
@ -785,14 +788,14 @@ func Benchmark_Cache(b *testing.B) {
app.Use(New()) app.Use(New())
app.Get("/demo", func(c *fiber.Ctx) error { app.Get("/demo", func(c *fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md") data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusTeapot).Send(data) return c.Status(fiber.StatusTeapot).Send(data)
}) })
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo") fctx.Request.SetRequestURI("/demo")
b.ReportAllocs() b.ReportAllocs()
@ -815,14 +818,14 @@ func Benchmark_Cache_Storage(b *testing.B) {
})) }))
app.Get("/demo", func(c *fiber.Ctx) error { app.Get("/demo", func(c *fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md") data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusTeapot).Send(data) return c.Status(fiber.StatusTeapot).Send(data)
}) })
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo") fctx.Request.SetRequestURI("/demo")
b.ReportAllocs() b.ReportAllocs()
@ -850,7 +853,7 @@ func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo") fctx.Request.SetRequestURI("/demo")
b.ReportAllocs() b.ReportAllocs()
@ -882,7 +885,7 @@ func Benchmark_Cache_MaxSize(b *testing.B) {
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()

View File

@ -1,7 +1,7 @@
package cache package cache
import ( import (
"fmt" "log"
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -49,10 +49,10 @@ type Config struct {
// Default: an in memory store for this process only // Default: an in memory store for this process only
Storage fiber.Storage Storage fiber.Storage
// Deprecated, use Storage instead // Deprecated: Use Storage instead
Store fiber.Storage Store fiber.Storage
// Deprecated, use KeyGenerator instead // Deprecated: Use KeyGenerator instead
Key func(*fiber.Ctx) string Key func(*fiber.Ctx) string
// allows you to store additional headers generated by next middlewares & handler // allows you to store additional headers generated by next middlewares & handler
@ -75,6 +75,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
Expiration: 1 * time.Minute, Expiration: 1 * time.Minute,
@ -102,11 +104,11 @@ func configDefault(config ...Config) Config {
// Set default values // Set default values
if cfg.Store != nil { if cfg.Store != nil {
fmt.Println("[CACHE] Store is deprecated, please use Storage") log.Printf("[CACHE] Store is deprecated, please use Storage\n")
cfg.Storage = cfg.Store cfg.Storage = cfg.Store
} }
if cfg.Key != nil { if cfg.Key != nil {
fmt.Println("[CACHE] Key is deprecated, please use KeyGenerator") log.Printf("[CACHE] Key is deprecated, please use KeyGenerator\n")
cfg.KeyGenerator = cfg.Key cfg.KeyGenerator = cfg.Key
} }
if cfg.Next == nil { if cfg.Next == nil {

View File

@ -41,7 +41,7 @@ func (h indexedHeap) Swap(i, j int) {
} }
func (h *indexedHeap) Push(x interface{}) { func (h *indexedHeap) Push(x interface{}) {
h.pushInternal(x.(heapEntry)) h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert // Forced type assertion required to implement the heap.Interface interface
} }
func (h *indexedHeap) Pop() interface{} { func (h *indexedHeap) Pop() interface{} {
@ -65,7 +65,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
idx = h.entries[:n+1][n].idx idx = h.entries[:n+1][n].idx
} else { } else {
idx = h.maxidx idx = h.maxidx
h.maxidx += 1 h.maxidx++
h.indices = append(h.indices, idx) h.indices = append(h.indices, idx)
} }
// Push manually to avoid allocation // Push manually to avoid allocation
@ -77,7 +77,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
} }
func (h *indexedHeap) removeInternal(realIdx int) (string, uint) { func (h *indexedHeap) removeInternal(realIdx int) (string, uint) {
x := heap.Remove(h, realIdx).(heapEntry) x := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
return x.key, x.bytes return x.key, x.bytes
} }

View File

@ -51,7 +51,7 @@ func newManager(storage fiber.Storage) *manager {
// acquire returns an *entry from the sync.Pool // acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item { func (m *manager) acquire() *item {
return m.pool.Get().(*item) return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
} }
// release and reset *entry to sync.Pool // release and reset *entry to sync.Pool
@ -69,38 +69,47 @@ func (m *manager) release(e *item) {
} }
// get data from storage or memory // get data from storage or memory
func (m *manager) get(key string) (it *item) { func (m *manager) get(key string) *item {
var it *item
if m.storage != nil { if m.storage != nil {
it = m.acquire() it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil { raw, err := m.storage.Get(key)
if err != nil {
return it
}
if raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil { if _, err := it.UnmarshalMsg(raw); err != nil {
return return it
} }
} }
return return it
} }
if it, _ = m.memory.Get(key).(*item); it == nil { if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire() it = m.acquire()
return it
} }
return return it
} }
// get raw data from storage or memory // get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) { func (m *manager) getRaw(key string) []byte {
var raw []byte
if m.storage != nil { if m.storage != nil {
raw, _ = m.storage.Get(key) raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Handle error here
} else { } else {
raw, _ = m.memory.Get(key).([]byte) raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Handle error here
} }
return return raw
} }
// set data to storage or memory // set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) { func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil { if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil { if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp) _ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
} }
// we can release data because it's serialized to database
m.release(it)
} else { } else {
m.memory.Set(key, it, exp) m.memory.Set(key, it, exp)
} }
@ -109,16 +118,16 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
// set data to storage or memory // set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) { func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil { if m.storage != nil {
_ = m.storage.Set(key, raw, exp) _ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
} else { } else {
m.memory.Set(key, raw, exp) m.memory.Set(key, raw, exp)
} }
} }
// delete data from storage or memory // delete data from storage or memory
func (m *manager) delete(key string) { func (m *manager) del(key string) {
if m.storage != nil { if m.storage != nil {
_ = m.storage.Delete(key) _ = m.storage.Delete(key) //nolint:errcheck // TODO: Handle error here
} else { } else {
m.memory.Delete(key) m.memory.Delete(key)
} }

View File

@ -2,6 +2,7 @@ package compress
import ( import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )

View File

@ -12,8 +12,10 @@ import (
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
) )
//nolint:gochecknoglobals // Using a global var is fine here
var filedata []byte var filedata []byte
//nolint:gochecknoinits // init() is used to populate a global var from a README file
func init() { func init() {
dat, err := os.ReadFile("../../.github/README.md") dat, err := os.ReadFile("../../.github/README.md")
if err != nil { if err != nil {
@ -34,7 +36,7 @@ func Test_Compress_Gzip(t *testing.T) {
return c.Send(filedata) return c.Send(filedata)
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req) resp, err := app.Test(req)
@ -64,7 +66,7 @@ func Test_Compress_Different_Level(t *testing.T) {
return c.Send(filedata) return c.Send(filedata)
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req) resp, err := app.Test(req)
@ -90,7 +92,7 @@ func Test_Compress_Deflate(t *testing.T) {
return c.Send(filedata) return c.Send(filedata)
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "deflate") req.Header.Set("Accept-Encoding", "deflate")
resp, err := app.Test(req) resp, err := app.Test(req)
@ -114,7 +116,7 @@ func Test_Compress_Brotli(t *testing.T) {
return c.Send(filedata) return c.Send(filedata)
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "br") req.Header.Set("Accept-Encoding", "br")
resp, err := app.Test(req, 10000) resp, err := app.Test(req, 10000)
@ -138,7 +140,7 @@ func Test_Compress_Disabled(t *testing.T) {
return c.Send(filedata) return c.Send(filedata)
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "br") req.Header.Set("Accept-Encoding", "br")
resp, err := app.Test(req) resp, err := app.Test(req)
@ -162,7 +164,7 @@ func Test_Compress_Next_Error(t *testing.T) {
return errors.New("next error") return errors.New("next error")
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req) resp, err := app.Test(req)
@ -185,7 +187,7 @@ func Test_Compress_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }

View File

@ -33,6 +33,8 @@ const (
) )
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
Level: LevelDefault, Level: LevelDefault,

View File

@ -1,7 +1,6 @@
package cors package cors
import ( import (
"net/http"
"strconv" "strconv"
"strings" "strings"
@ -54,6 +53,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
AllowOrigins: "*", AllowOrigins: "*",
@ -128,7 +129,7 @@ func New(config ...Config) fiber.Handler {
} }
// Simple request // Simple request
if c.Method() != http.MethodOptions { if c.Method() != fiber.MethodOptions {
c.Vary(fiber.HeaderOrigin) c.Vary(fiber.HeaderOrigin)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)

View File

@ -6,6 +6,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -237,7 +238,7 @@ func Test_CORS_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }

View File

@ -1,6 +1,8 @@
package cors package cors
import "strings" import (
"strings"
)
func matchScheme(domain, pattern string) bool { func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":") didx := strings.Index(domain, ":")
@ -20,18 +22,20 @@ func matchSubdomain(domain, pattern string) bool {
} }
domAuth := domain[didx+3:] domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain // to avoid long loop by invalid long domain
if len(domAuth) > 253 { const maxDomainLen = 253
if len(domAuth) > maxDomainLen {
return false return false
} }
patAuth := pattern[pidx+3:] patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".") domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".") patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- { const divHalf = 2
for i := len(domComp)/divHalf - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i] domComp[i], domComp[opp] = domComp[opp], domComp[i]
} }
for i := len(patComp)/2 - 1; i >= 0; i-- { for i := len(patComp)/divHalf - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i] patComp[i], patComp[opp] = patComp[opp], patComp[i]
} }

View File

@ -1,7 +1,7 @@
package csrf package csrf
import ( import (
"fmt" "log"
"net/textproto" "net/textproto"
"strings" "strings"
"time" "time"
@ -80,13 +80,13 @@ type Config struct {
// Optional. Default: utils.UUID // Optional. Default: utils.UUID
KeyGenerator func() string KeyGenerator func() string
// Deprecated, please use Expiration // Deprecated: Please use Expiration
CookieExpires time.Duration CookieExpires time.Duration
// Deprecated, please use Cookie* related fields // Deprecated: Please use Cookie* related fields
Cookie *fiber.Cookie Cookie *fiber.Cookie
// Deprecated, please use KeyLookup // Deprecated: Please use KeyLookup
TokenLookup string TokenLookup string
// ErrorHandler is executed when an error is returned from fiber.Handler. // ErrorHandler is executed when an error is returned from fiber.Handler.
@ -105,6 +105,8 @@ type Config struct {
const HeaderName = "X-Csrf-Token" const HeaderName = "X-Csrf-Token"
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
KeyLookup: "header:" + HeaderName, KeyLookup: "header:" + HeaderName,
CookieName: "csrf_", CookieName: "csrf_",
@ -116,7 +118,7 @@ var ConfigDefault = Config{
} }
// default ErrorHandler that process return error from fiber.Handler // default ErrorHandler that process return error from fiber.Handler
var defaultErrorHandler = func(c *fiber.Ctx, err error) error { func defaultErrorHandler(_ *fiber.Ctx, _ error) error {
return fiber.ErrForbidden return fiber.ErrForbidden
} }
@ -132,15 +134,15 @@ func configDefault(config ...Config) Config {
// Set default values // Set default values
if cfg.TokenLookup != "" { if cfg.TokenLookup != "" {
fmt.Println("[CSRF] TokenLookup is deprecated, please use KeyLookup") log.Printf("[CSRF] TokenLookup is deprecated, please use KeyLookup\n")
cfg.KeyLookup = cfg.TokenLookup cfg.KeyLookup = cfg.TokenLookup
} }
if int(cfg.CookieExpires.Seconds()) > 0 { if int(cfg.CookieExpires.Seconds()) > 0 {
fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration") log.Printf("[CSRF] CookieExpires is deprecated, please use Expiration\n")
cfg.Expiration = cfg.CookieExpires cfg.Expiration = cfg.CookieExpires
} }
if cfg.Cookie != nil { if cfg.Cookie != nil {
fmt.Println("[CSRF] Cookie is deprecated, please use Cookie* related fields") log.Printf("[CSRF] Cookie is deprecated, please use Cookie* related fields\n")
if cfg.Cookie.Name != "" { if cfg.Cookie.Name != "" {
cfg.CookieName = cfg.Cookie.Name cfg.CookieName = cfg.Cookie.Name
} }
@ -178,7 +180,8 @@ func configDefault(config ...Config) Config {
// Generate the correct extractor to get the token from the correct location // Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":") selectors := strings.Split(cfg.KeyLookup, ":")
if len(selectors) != 2 { const numParts = 2
if len(selectors) != numParts {
panic("[CSRF] KeyLookup must in the form of <source>:<key>") panic("[CSRF] KeyLookup must in the form of <source>:<key>")
} }

View File

@ -7,9 +7,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
var ( var errTokenNotFound = errors.New("csrf token not found")
errTokenNotFound = errors.New("csrf token not found")
)
// New creates a new middleware handler // New creates a new middleware handler
func New(config ...Config) fiber.Handler { func New(config ...Config) fiber.Handler {
@ -22,7 +20,7 @@ func New(config ...Config) fiber.Handler {
dummyValue := []byte{'+'} dummyValue := []byte{'+'}
// Return new handler // Return new handler
return func(c *fiber.Ctx) (err error) { return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) { if cfg.Next != nil && cfg.Next(c) {
return c.Next() return c.Next()
@ -39,7 +37,7 @@ func New(config ...Config) fiber.Handler {
// Assume that anything not defined as 'safe' by RFC7231 needs protection // Assume that anything not defined as 'safe' by RFC7231 needs protection
// Extract token from client request i.e. header, query, param, form or cookie // Extract token from client request i.e. header, query, param, form or cookie
token, err = cfg.Extractor(c) token, err := cfg.Extractor(c)
if err != nil { if err != nil {
return cfg.ErrorHandler(c, err) return cfg.ErrorHandler(c, err)
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -23,7 +24,7 @@ func Test_CSRF(t *testing.T) {
h := app.Handler() h := app.Handler()
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
methods := [4]string{"GET", "HEAD", "OPTIONS", "TRACE"} methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
for _, method := range methods { for _, method := range methods {
// Generate CSRF token // Generate CSRF token
@ -33,14 +34,14 @@ func Test_CSRF(t *testing.T) {
// Without CSRF cookie // Without CSRF cookie
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx) h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode()) utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Empty/invalid CSRF token // Empty/invalid CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, "johndoe") ctx.Request.Header.Set(HeaderName, "johndoe")
h(ctx) h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode()) utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -55,7 +56,7 @@ func Test_CSRF(t *testing.T) {
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.Set(HeaderName, token)
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -72,7 +73,7 @@ func Test_CSRF_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -92,7 +93,7 @@ func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
h := app.Handler() h := app.Handler()
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
} }
@ -110,7 +111,7 @@ func Test_CSRF_From_Form(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token // Invalid CSRF token
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
h(ctx) h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode()) utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -118,12 +119,12 @@ func Test_CSRF_From_Form(t *testing.T) {
// Generate CSRF token // Generate CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1] token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
ctx.Request.SetBodyString("_csrf=" + token) ctx.Request.SetBodyString("_csrf=" + token)
h(ctx) h(ctx)
@ -144,7 +145,7 @@ func Test_CSRF_From_Query(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token // Invalid CSRF token
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID()) ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
h(ctx) h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode()) utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -152,7 +153,7 @@ func Test_CSRF_From_Query(t *testing.T) {
// Generate CSRF token // Generate CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/") ctx.Request.SetRequestURI("/")
h(ctx) h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
@ -161,7 +162,7 @@ func Test_CSRF_From_Query(t *testing.T) {
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.SetRequestURI("/?_csrf=" + token) ctx.Request.SetRequestURI("/?_csrf=" + token)
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "OK", string(ctx.Response.Body())) utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
@ -181,7 +182,7 @@ func Test_CSRF_From_Param(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token // Invalid CSRF token
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/" + utils.UUID()) ctx.Request.SetRequestURI("/" + utils.UUID())
h(ctx) h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode()) utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -189,7 +190,7 @@ func Test_CSRF_From_Param(t *testing.T) {
// Generate CSRF token // Generate CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/" + utils.UUID()) ctx.Request.SetRequestURI("/" + utils.UUID())
h(ctx) h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
@ -198,7 +199,7 @@ func Test_CSRF_From_Param(t *testing.T) {
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.SetRequestURI("/" + token) ctx.Request.SetRequestURI("/" + token)
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "OK", string(ctx.Response.Body())) utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
@ -218,7 +219,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token // Invalid CSRF token
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/") ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";") ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
h(ctx) h(ctx)
@ -227,7 +228,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
// Generate CSRF token // Generate CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/") ctx.Request.SetRequestURI("/")
h(ctx) h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
@ -235,7 +236,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";") ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
ctx.Request.SetRequestURI("/") ctx.Request.SetRequestURI("/")
h(ctx) h(ctx)
@ -268,7 +269,7 @@ func Test_CSRF_From_Custom(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token // Invalid CSRF token
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
h(ctx) h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode()) utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -276,12 +277,12 @@ func Test_CSRF_From_Custom(t *testing.T) {
// Generate CSRF token // Generate CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1] token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
ctx.Request.SetBodyString("_csrf=" + token) ctx.Request.SetBodyString("_csrf=" + token)
h(ctx) h(ctx)
@ -307,13 +308,13 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Generate CSRF token // Generate CSRF token
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
// invalid CSRF token // invalid CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, "johndoe") ctx.Request.Header.Set(HeaderName, "johndoe")
h(ctx) h(ctx)
utils.AssertEqual(t, 419, ctx.Response.StatusCode()) utils.AssertEqual(t, 419, ctx.Response.StatusCode())
@ -339,13 +340,13 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Generate CSRF token // Generate CSRF token
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
// empty CSRF token // empty CSRF token
ctx.Request.Reset() ctx.Request.Reset()
ctx.Response.Reset() ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx) h(ctx)
utils.AssertEqual(t, 419, ctx.Response.StatusCode()) utils.AssertEqual(t, 419, ctx.Response.StatusCode())
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body())) utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
@ -355,7 +356,7 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) { // func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
// t.Parallel() // t.Parallel()
// app := fiber.New() // app := fiber.New()
//
// app.Use(New()) // app.Use(New())
// app.Get("/", func(c *fiber.Ctx) error { // app.Get("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK) // return c.SendStatus(fiber.StatusOK)
@ -366,11 +367,11 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
// app.Post("/", func(c *fiber.Ctx) error { // app.Post("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK) // return c.SendStatus(fiber.StatusOK)
// }) // })
//
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) // resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
// utils.AssertEqual(t, nil, err) // utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) // utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
//
// var token string // var token string
// for _, c := range resp.Cookies() { // for _, c := range resp.Cookies() {
// if c.Name != ConfigDefault.CookieName { // if c.Name != ConfigDefault.CookieName {
@ -379,25 +380,25 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
// token = c.Value // token = c.Value
// break // break
// } // }
//
// fmt.Println("token", token) // fmt.Println("token", token)
//
// getReq := httptest.NewRequest(http.MethodGet, "/", nil) // getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
// getReq.Header.Set(HeaderName, token) // getReq.Header.Set(HeaderName, token)
// resp, err = app.Test(getReq) // resp, err = app.Test(getReq)
//
// getReq = httptest.NewRequest(http.MethodGet, "/test", nil) // getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest") // getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// getReq.Header.Set(fiber.HeaderCacheControl, "no") // getReq.Header.Set(fiber.HeaderCacheControl, "no")
// getReq.Header.Set(HeaderName, token) // getReq.Header.Set(HeaderName, token)
//
// resp, err = app.Test(getReq) // resp, err = app.Test(getReq)
//
// getReq.Header.Set(fiber.HeaderAccept, "*/*") // getReq.Header.Set(fiber.HeaderAccept, "*/*")
// getReq.Header.Del(HeaderName) // getReq.Header.Del(HeaderName)
// resp, err = app.Test(getReq) // resp, err = app.Test(getReq)
//
// postReq := httptest.NewRequest(http.MethodPost, "/", nil) // postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest") // postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// postReq.Header.Set(HeaderName, token) // postReq.Header.Set(HeaderName, token)
// resp, err = app.Test(postReq) // resp, err = app.Test(postReq)
@ -417,12 +418,12 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Generate CSRF token // Generate CSRF token
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1] token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.Set(HeaderName, token)
b.ReportAllocs() b.ReportAllocs()
@ -449,7 +450,7 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// Generate CSRF token // Generate CSRF token
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()

View File

@ -41,74 +41,23 @@ func newManager(storage fiber.Storage) *manager {
return manager return manager
} }
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
// don't release item if we using memory storage
if m.storage != nil {
return
}
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory // get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) { func (m *manager) getRaw(key string) []byte {
var raw []byte
if m.storage != nil { if m.storage != nil {
raw, _ = m.storage.Get(key) raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error
} else { } else {
raw, _ = m.memory.Get(key).([]byte) raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
} else {
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), it, exp)
} }
return raw
} }
// set data to storage or memory // set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) { func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil { if m.storage != nil {
_ = m.storage.Set(key, raw, exp) _ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error
} else { } else {
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), raw, exp) m.memory.Set(utils.CopyString(key), raw, exp)
} }
} }
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -1,6 +1,8 @@
package encryptcookie package encryptcookie
import "github.com/gofiber/fiber/v2" import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware. // Config defines the config for middleware.
type Config struct { type Config struct {
@ -32,6 +34,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
Except: []string{"csrf_"}, Except: []string{"csrf_"},

View File

@ -2,6 +2,7 @@ package encryptcookie
import ( import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )

View File

@ -7,9 +7,11 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
//nolint:gochecknoglobals // Using a global var is fine here
var testKey = GenerateKey() var testKey = GenerateKey()
func Test_Middleware_Encrypt_Cookie(t *testing.T) { func Test_Middleware_Encrypt_Cookie(t *testing.T) {
@ -35,14 +37,14 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
// Test empty cookie // Test empty cookie
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "value=", string(ctx.Response.Body())) utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
// Test invalid cookie // Test invalid cookie
ctx = &fasthttp.RequestCtx{} ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", "Invalid") ctx.Request.Header.SetCookie("test", "Invalid")
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -54,18 +56,19 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
// Test valid cookie // Test valid cookie
ctx = &fasthttp.RequestCtx{} ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
encryptedCookie := fasthttp.Cookie{} encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test") encryptedCookie.SetKey("test")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey) decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", decryptedCookieValue) utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
ctx = &fasthttp.RequestCtx{} ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value())) ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -91,7 +94,7 @@ func Test_Encrypt_Cookie_Next(t *testing.T) {
return nil return nil
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value) utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
} }
@ -123,7 +126,7 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
h := app.Handler() h := app.Handler()
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -135,7 +138,8 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
encryptedCookie := fasthttp.Cookie{} encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test2") encryptedCookie.SetKey("test2")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey) decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", decryptedCookieValue) utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
} }
@ -169,18 +173,19 @@ func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
h := app.Handler() h := app.Handler()
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("POST") ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())
encryptedCookie := fasthttp.Cookie{} encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test") encryptedCookie.SetKey("test")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decodedBytes, _ := base64.StdEncoding.DecodeString(string(encryptedCookie.Value())) decodedBytes, err := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", string(decodedBytes)) utils.AssertEqual(t, "SomeThing", string(decodedBytes))
ctx = &fasthttp.RequestCtx{} ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value())) ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
h(ctx) h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode()) utils.AssertEqual(t, 200, ctx.Response.StatusCode())

View File

@ -6,47 +6,56 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt"
"io" "io"
) )
// EncryptCookie Encrypts a cookie value with specific encryption key // EncryptCookie Encrypts a cookie value with specific encryption key
func EncryptCookie(value, key string) (string, error) { func EncryptCookie(value, key string) (string, error) {
keyDecoded, _ := base64.StdEncoding.DecodeString(key) keyDecoded, err := base64.StdEncoding.DecodeString(key)
plaintext := []byte(value) if err != nil {
return "", fmt.Errorf("failed to base64-decode key: %w", err)
}
block, err := aes.NewCipher(keyDecoded) block, err := aes.NewCipher(keyDecoded)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to create AES cipher: %w", err)
} }
gcm, err := cipher.NewGCM(block) gcm, err := cipher.NewGCM(block)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to create GCM mode: %w", err)
} }
nonce := make([]byte, gcm.NonceSize()) nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil { if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return "", err return "", fmt.Errorf("failed to read: %w", err)
} }
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil return base64.StdEncoding.EncodeToString(ciphertext), nil
} }
// DecryptCookie Decrypts a cookie value with specific encryption key // DecryptCookie Decrypts a cookie value with specific encryption key
func DecryptCookie(value, key string) (string, error) { func DecryptCookie(value, key string) (string, error) {
keyDecoded, _ := base64.StdEncoding.DecodeString(key) keyDecoded, err := base64.StdEncoding.DecodeString(key)
enc, _ := base64.StdEncoding.DecodeString(value) if err != nil {
return "", fmt.Errorf("failed to base64-decode key: %w", err)
}
enc, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return "", fmt.Errorf("failed to base64-decode value: %w", err)
}
block, err := aes.NewCipher(keyDecoded) block, err := aes.NewCipher(keyDecoded)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to create AES cipher: %w", err)
} }
gcm, err := cipher.NewGCM(block) gcm, err := cipher.NewGCM(block)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to create GCM mode: %w", err)
} }
nonceSize := gcm.NonceSize() nonceSize := gcm.NonceSize()
@ -59,7 +68,7 @@ func DecryptCookie(value, key string) (string, error) {
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to decrypt ciphertext: %w", err)
} }
return string(plaintext), nil return string(plaintext), nil
@ -67,7 +76,8 @@ func DecryptCookie(value, key string) (string, error) {
// GenerateKey Generates an encryption key // GenerateKey Generates an encryption key
func GenerateKey() string { func GenerateKey() string {
ret := make([]byte, 32) const keyLen = 32
ret := make([]byte, keyLen)
if _, err := rand.Read(ret); err != nil { if _, err := rand.Read(ret); err != nil {
panic(err) panic(err)

View File

@ -23,10 +23,8 @@ func (envVar *EnvVar) set(key, val string) {
envVar.Vars[key] = val envVar.Vars[key] = val
} }
var defaultConfig = Config{}
func New(config ...Config) fiber.Handler { func New(config ...Config) fiber.Handler {
var cfg = defaultConfig var cfg Config
if len(config) > 0 { if len(config) > 0 {
cfg = config[0] cfg = config[0]
} }
@ -57,8 +55,9 @@ func newEnvVar(cfg Config) *EnvVar {
} }
} }
} else { } else {
const numElems = 2
for _, envVal := range os.Environ() { for _, envVal := range os.Environ() {
keyVal := strings.SplitN(envVal, "=", 2) keyVal := strings.SplitN(envVal, "=", numElems)
if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists { if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists {
vars.set(keyVal[0], keyVal[1]) vars.set(keyVal[0], keyVal[1])
} }

View File

@ -1,6 +1,8 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package envvar package envvar
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -12,16 +14,25 @@ import (
) )
func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) { func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
os.Setenv("testKey", "testEnvValue") err := os.Setenv("testKey", "testEnvValue")
os.Setenv("anotherEnvKey", "anotherEnvVal") utils.AssertEqual(t, nil, err)
os.Setenv("excludeKey", "excludeEnvValue") err = os.Setenv("anotherEnvKey", "anotherEnvVal")
defer os.Unsetenv("testKey") utils.AssertEqual(t, nil, err)
defer os.Unsetenv("anotherEnvKey") err = os.Setenv("excludeKey", "excludeEnvValue")
defer os.Unsetenv("excludeKey") utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testKey")
utils.AssertEqual(t, nil, err)
err = os.Unsetenv("anotherEnvKey")
utils.AssertEqual(t, nil, err)
err = os.Unsetenv("excludeKey")
utils.AssertEqual(t, nil, err)
}()
vars := newEnvVar(Config{ vars := newEnvVar(Config{
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"}, ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
ExcludeVars: map[string]string{"excludeKey": ""}}) ExcludeVars: map[string]string{"excludeKey": ""},
})
utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue") utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue")
utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal") utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal")
@ -30,21 +41,28 @@ func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
} }
func TestEnvVarHandler(t *testing.T) { func TestEnvVarHandler(t *testing.T) {
os.Setenv("testKey", "testVal") err := os.Setenv("testKey", "testVal")
defer os.Unsetenv("testKey") utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testKey")
utils.AssertEqual(t, nil, err)
}()
expectedEnvVarResponse, _ := json.Marshal( expectedEnvVarResponse, err := json.Marshal(
struct { struct {
Vars map[string]string `json:"vars"` Vars map[string]string `json:"vars"`
}{ }{
map[string]string{"testKey": "testVal"}, map[string]string{"testKey": "testVal"},
}) })
utils.AssertEqual(t, nil, err)
app := fiber.New() app := fiber.New()
app.Use("/envvars", New(Config{ app.Use("/envvars", New(Config{
ExportVars: map[string]string{"testKey": ""}})) ExportVars: map[string]string{"testKey": ""},
}))
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -57,14 +75,16 @@ func TestEnvVarHandler(t *testing.T) {
func TestEnvVarHandlerNotMatched(t *testing.T) { func TestEnvVarHandlerNotMatched(t *testing.T) {
app := fiber.New() app := fiber.New()
app.Use("/envvars", New(Config{ app.Use("/envvars", New(Config{
ExportVars: map[string]string{"testKey": ""}})) ExportVars: map[string]string{"testKey": ""},
}))
app.Get("/another-path", func(ctx *fiber.Ctx) error { app.Get("/another-path", func(ctx *fiber.Ctx) error {
utils.AssertEqual(t, nil, ctx.SendString("OK")) utils.AssertEqual(t, nil, ctx.SendString("OK"))
return nil return nil
}) })
req, _ := http.NewRequest("GET", "http://localhost/another-path", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -75,13 +95,18 @@ func TestEnvVarHandlerNotMatched(t *testing.T) {
} }
func TestEnvVarHandlerDefaultConfig(t *testing.T) { func TestEnvVarHandlerDefaultConfig(t *testing.T) {
os.Setenv("testEnvKey", "testEnvVal") err := os.Setenv("testEnvKey", "testEnvVal")
defer os.Unsetenv("testEnvKey") utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testEnvKey")
utils.AssertEqual(t, nil, err)
}()
app := fiber.New() app := fiber.New()
app.Use("/envvars", New()) app.Use("/envvars", New())
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -98,7 +123,8 @@ func TestEnvVarHandlerMethod(t *testing.T) {
app := fiber.New() app := fiber.New()
app.Use("/envvars", New()) app.Use("/envvars", New())
req, _ := http.NewRequest("POST", "http://localhost/envvars", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode) utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
@ -107,14 +133,19 @@ func TestEnvVarHandlerMethod(t *testing.T) {
func TestEnvVarHandlerSpecialValue(t *testing.T) { func TestEnvVarHandlerSpecialValue(t *testing.T) {
testEnvKey := "testEnvKey" testEnvKey := "testEnvKey"
fakeBase64 := "testBase64:TQ==" fakeBase64 := "testBase64:TQ=="
os.Setenv(testEnvKey, fakeBase64) err := os.Setenv(testEnvKey, fakeBase64)
defer os.Unsetenv(testEnvKey) utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv(testEnvKey)
utils.AssertEqual(t, nil, err)
}()
app := fiber.New() app := fiber.New()
app.Use("/envvars", New()) app.Use("/envvars", New())
app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}})) app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -126,7 +157,8 @@ func TestEnvVarHandlerSpecialValue(t *testing.T) {
val := envVars.Vars[testEnvKey] val := envVars.Vars[testEnvKey]
utils.AssertEqual(t, fakeBase64, val) utils.AssertEqual(t, fakeBase64, val)
req, _ = http.NewRequest("GET", "http://localhost/envvars/export", nil) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", nil)
utils.AssertEqual(t, nil, err)
resp, err = app.Test(req) resp, err = app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)

View File

@ -23,6 +23,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Weak: false, Weak: false,
Next: nil, Next: nil,

View File

@ -5,12 +5,8 @@ import (
"hash/crc32" "hash/crc32"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/valyala/bytebufferpool"
)
var ( "github.com/valyala/bytebufferpool"
normalizedHeaderETag = []byte("Etag")
weakPrefix = []byte("W/")
) )
// New creates a new middleware handler // New creates a new middleware handler
@ -18,32 +14,38 @@ func New(config ...Config) fiber.Handler {
// Set default config // Set default config
cfg := configDefault(config...) cfg := configDefault(config...)
crc32q := crc32.MakeTable(0xD5828281) var (
normalizedHeaderETag = []byte("Etag")
weakPrefix = []byte("W/")
)
const crcPol = 0xD5828281
crc32q := crc32.MakeTable(crcPol)
// Return new handler // Return new handler
return func(c *fiber.Ctx) (err error) { return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) { if cfg.Next != nil && cfg.Next(c) {
return c.Next() return c.Next()
} }
// Return err if next handler returns one // Return err if next handler returns one
if err = c.Next(); err != nil { if err := c.Next(); err != nil {
return return err
} }
// Don't generate ETags for invalid responses // Don't generate ETags for invalid responses
if c.Response().StatusCode() != fiber.StatusOK { if c.Response().StatusCode() != fiber.StatusOK {
return return nil
} }
body := c.Response().Body() body := c.Response().Body()
// Skips ETag if no response body is present // Skips ETag if no response body is present
if len(body) == 0 { if len(body) == 0 {
return return nil
} }
// Skip ETag if header is already present // Skip ETag if header is already present
if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil { if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
return return nil
} }
// Generate ETag for response // Generate ETag for response
@ -52,14 +54,14 @@ func New(config ...Config) fiber.Handler {
// Enable weak tag // Enable weak tag
if cfg.Weak { if cfg.Weak {
_, _ = bb.Write(weakPrefix) _, _ = bb.Write(weakPrefix) //nolint:errcheck // This will never fail
} }
_ = bb.WriteByte('"') _ = bb.WriteByte('"') //nolint:errcheck // This will never fail
bb.B = appendUint(bb.Bytes(), uint32(len(body))) bb.B = appendUint(bb.Bytes(), uint32(len(body)))
_ = bb.WriteByte('-') _ = bb.WriteByte('-') //nolint:errcheck // This will never fail
bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q)) bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q))
_ = bb.WriteByte('"') _ = bb.WriteByte('"') //nolint:errcheck // This will never fail
etag := bb.Bytes() etag := bb.Bytes()
@ -78,7 +80,7 @@ func New(config ...Config) fiber.Handler {
// W/1 != W/2 || W/1 != 2 // W/1 != W/2 || W/1 != 2
c.Response().Header.SetCanonical(normalizedHeaderETag, etag) c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
return return nil
} }
if bytes.Contains(clientEtag, etag) { if bytes.Contains(clientEtag, etag) {
@ -90,7 +92,7 @@ func New(config ...Config) fiber.Handler {
// 1 != 2 // 1 != 2
c.Response().Header.SetCanonical(normalizedHeaderETag, etag) c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
return return nil
} }
} }
@ -102,7 +104,7 @@ func appendUint(dst []byte, n uint32) []byte {
var q uint32 var q uint32
for n >= 10 { for n >= 10 {
i-- i--
q = n / 10 q = n / 10 //nolint:gomnd // TODO: Explain why we divide by 10 here
buf[i] = '0' + byte(n-q*10) buf[i] = '0' + byte(n-q*10)
n = q n = q
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -21,7 +22,7 @@ func Test_ETag_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -37,7 +38,7 @@ func Test_ETag_SkipError(t *testing.T) {
return fiber.ErrForbidden return fiber.ErrForbidden
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode) utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
} }
@ -53,7 +54,7 @@ func Test_ETag_NotStatusOK(t *testing.T) {
return c.SendStatus(fiber.StatusCreated) return c.SendStatus(fiber.StatusCreated)
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode) utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode)
} }
@ -69,7 +70,7 @@ func Test_ETag_NoBody(t *testing.T) {
return nil return nil
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
} }
@ -91,7 +92,7 @@ func Test_ETag_NewEtag(t *testing.T) {
}) })
} }
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper() t.Helper()
app := fiber.New() app := fiber.New()
@ -102,7 +103,7 @@ func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) {
return c.SendString("Hello, World!") return c.SendString("Hello, World!")
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch { if headerIfNoneMatch {
etag := `"non-match"` etag := `"non-match"`
if matched { if matched {
@ -145,7 +146,7 @@ func Test_ETag_WeakEtag(t *testing.T) {
}) })
} }
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper() t.Helper()
app := fiber.New() app := fiber.New()
@ -156,7 +157,7 @@ func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) {
return c.SendString("Hello, World!") return c.SendString("Hello, World!")
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch { if headerIfNoneMatch {
etag := `W/"non-match"` etag := `W/"non-match"`
if matched { if matched {
@ -199,7 +200,7 @@ func Test_ETag_CustomEtag(t *testing.T) {
}) })
} }
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper() t.Helper()
app := fiber.New() app := fiber.New()
@ -214,7 +215,7 @@ func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) {
return c.SendString("Hello, World!") return c.SendString("Hello, World!")
}) })
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch { if headerIfNoneMatch {
etag := `"non-match"` etag := `"non-match"`
if matched { if matched {
@ -255,7 +256,7 @@ func Test_ETag_CustomEtagPut(t *testing.T) {
return c.SendString("Hello, World!") return c.SendString("Hello, World!")
}) })
req := httptest.NewRequest("PUT", "/", nil) req := httptest.NewRequest(fiber.MethodPut, "/", nil)
req.Header.Set(fiber.HeaderIfMatch, `"non-match"`) req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -275,7 +276,7 @@ func Benchmark_Etag(b *testing.B) {
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/") fctx.Request.SetRequestURI("/")
b.ReportAllocs() b.ReportAllocs()

View File

@ -1,6 +1,8 @@
package expvar package expvar
import "github.com/gofiber/fiber/v2" import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware. // Config defines the config for middleware.
type Config struct { type Config struct {
@ -10,6 +12,7 @@ type Config struct {
Next func(c *fiber.Ctx) bool Next func(c *fiber.Ctx) bool
} }
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
} }

View File

@ -4,6 +4,7 @@ import (
"strings" "strings"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp/expvarhandler" "github.com/valyala/fasthttp/expvarhandler"
) )
@ -29,6 +30,6 @@ func New(config ...Config) fiber.Handler {
return nil return nil
} }
return c.Redirect("/debug/vars", 302) return c.Redirect("/debug/vars", fiber.StatusFound)
} }
} }

View File

@ -34,6 +34,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
File: "", File: "",

View File

@ -1,16 +1,18 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package favicon package favicon
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"strings" "strings"
"testing" "testing"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
) )
// go test -run Test_Middleware_Favicon // go test -run Test_Middleware_Favicon
@ -25,22 +27,22 @@ func Test_Middleware_Favicon(t *testing.T) {
}) })
// Skip Favicon middleware // Skip Favicon middleware
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest("GET", "/favicon.ico", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest("OPTIONS", "/favicon.ico", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodOptions, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest("PUT", "/favicon.ico", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodPut, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
utils.AssertEqual(t, "GET, HEAD, OPTIONS", resp.Header.Get(fiber.HeaderAllow)) utils.AssertEqual(t, strings.Join([]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions}, ", "), resp.Header.Get(fiber.HeaderAllow))
} }
// go test -run Test_Middleware_Favicon_Not_Found // go test -run Test_Middleware_Favicon_Not_Found
@ -70,8 +72,7 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
return nil return nil
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
@ -83,15 +84,15 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
// TODO use os.Dir if fiber upgrades to 1.16 // TODO use os.Dir if fiber upgrades to 1.16
type mockFS struct{} type mockFS struct{}
func (m mockFS) Open(name string) (http.File, error) { func (mockFS) Open(name string) (http.File, error) {
if name == "/" { if name == "/" {
name = "." name = "."
} else { } else {
name = strings.TrimPrefix(name, "/") name = strings.TrimPrefix(name, "/")
} }
file, err := os.Open(name) file, err := os.Open(name) //nolint:gosec // We're in a test func, so this is fine
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to open: %w", err)
} }
return file, nil return file, nil
} }
@ -106,7 +107,7 @@ func Test_Middleware_Favicon_FileSystem(t *testing.T) {
FileSystem: mockFS{}, FileSystem: mockFS{},
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
@ -123,7 +124,7 @@ func Test_Middleware_Favicon_CacheControl(t *testing.T) {
File: "../../.github/testdata/favicon.ico", File: "../../.github/testdata/favicon.ico",
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
@ -159,7 +160,7 @@ func Test_Favicon_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }

View File

@ -1,6 +1,7 @@
package filesystem package filesystem
import ( import (
"fmt"
"net/http" "net/http"
"os" "os"
"strconv" "strconv"
@ -55,6 +56,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
Root: nil, Root: nil,
@ -102,7 +105,7 @@ func New(config ...Config) fiber.Handler {
cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge) cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge)
// Return new handler // Return new handler
return func(c *fiber.Ctx) (err error) { return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) { if cfg.Next != nil && cfg.Next(c) {
return c.Next() return c.Next()
@ -131,28 +134,23 @@ func New(config ...Config) fiber.Handler {
path = cfg.PathPrefix + path path = cfg.PathPrefix + path
} }
var (
file http.File
stat os.FileInfo
)
if len(path) > 1 { if len(path) > 1 {
path = utils.TrimRight(path, '/') path = utils.TrimRight(path, '/')
} }
file, err = cfg.Root.Open(path) file, err := cfg.Root.Open(path)
if err != nil && os.IsNotExist(err) && cfg.NotFoundFile != "" { if err != nil && os.IsNotExist(err) && cfg.NotFoundFile != "" {
file, err = cfg.Root.Open(cfg.NotFoundFile) file, err = cfg.Root.Open(cfg.NotFoundFile)
} }
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return c.Status(fiber.StatusNotFound).Next() return c.Status(fiber.StatusNotFound).Next()
} }
return return fmt.Errorf("failed to open: %w", err)
} }
if stat, err = file.Stat(); err != nil { stat, err := file.Stat()
return if err != nil {
return fmt.Errorf("failed to stat: %w", err)
} }
// Serve index if path is directory // Serve index if path is directory
@ -200,7 +198,7 @@ func New(config ...Config) fiber.Handler {
c.Response().SkipBody = true c.Response().SkipBody = true
c.Response().Header.SetContentLength(contentLength) c.Response().Header.SetContentLength(contentLength)
if err := file.Close(); err != nil { if err := file.Close(); err != nil {
return err return fmt.Errorf("failed to close: %w", err)
} }
return nil return nil
} }
@ -210,22 +208,18 @@ func New(config ...Config) fiber.Handler {
} }
// SendFile ... // SendFile ...
func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) { func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) error {
var ( file, err := fs.Open(path)
file http.File
stat os.FileInfo
)
file, err = fs.Open(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return fiber.ErrNotFound return fiber.ErrNotFound
} }
return err return fmt.Errorf("failed to open: %w", err)
} }
if stat, err = file.Stat(); err != nil { stat, err := file.Stat()
return err if err != nil {
return fmt.Errorf("failed to stat: %w", err)
} }
// Serve index if path is directory // Serve index if path is directory
@ -268,7 +262,7 @@ func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) {
c.Response().SkipBody = true c.Response().SkipBody = true
c.Response().Header.SetContentLength(contentLength) c.Response().Header.SetContentLength(contentLength)
if err := file.Close(); err != nil { if err := file.Close(); err != nil {
return err return fmt.Errorf("failed to close: %w", err)
} }
return nil return nil
} }

View File

@ -1,6 +1,8 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package filesystem package filesystem
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -119,7 +121,7 @@ func Test_FileSystem(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
resp, err := app.Test(httptest.NewRequest("GET", tt.url, nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tt.url, nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tt.statusCode, resp.StatusCode) utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
@ -142,7 +144,7 @@ func Test_FileSystem_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -168,7 +170,8 @@ func Test_FileSystem_Head(t *testing.T) {
Root: http.Dir("../../.github/testdata/fs"), Root: http.Dir("../../.github/testdata/fs"),
})) }))
req, _ := http.NewRequest(fiber.MethodHead, "/test", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/test", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
@ -182,7 +185,8 @@ func Test_FileSystem_NoRoot(t *testing.T) {
app := fiber.New() app := fiber.New()
app.Use(New()) app.Use(New())
_, _ = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
} }
func Test_FileSystem_UsingParam(t *testing.T) { func Test_FileSystem_UsingParam(t *testing.T) {
@ -193,7 +197,8 @@ func Test_FileSystem_UsingParam(t *testing.T) {
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html") return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
}) })
req, _ := http.NewRequest(fiber.MethodHead, "/index", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/index", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
@ -207,7 +212,8 @@ func Test_FileSystem_UsingParam_NonFile(t *testing.T) {
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html") return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
}) })
req, _ := http.NewRequest(fiber.MethodHead, "/template", nil) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/template", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req) resp, err := app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode) utils.AssertEqual(t, 404, resp.StatusCode)

View File

@ -13,18 +13,18 @@ import (
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
) )
func getFileExtension(path string) string { func getFileExtension(p string) string {
n := strings.LastIndexByte(path, '.') n := strings.LastIndexByte(p, '.')
if n < 0 { if n < 0 {
return "" return ""
} }
return path[n:] return p[n:]
} }
func dirList(c *fiber.Ctx, f http.File) error { func dirList(c *fiber.Ctx, f http.File) error {
fileinfos, err := f.Readdir(-1) fileinfos, err := f.Readdir(-1)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to read dir: %w", err)
} }
fm := make(map[string]os.FileInfo, len(fileinfos)) fm := make(map[string]os.FileInfo, len(fileinfos))
@ -36,13 +36,13 @@ func dirList(c *fiber.Ctx, f http.File) error {
} }
basePathEscaped := html.EscapeString(c.Path()) basePathEscaped := html.EscapeString(c.Path())
fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped) _, _ = fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped) _, _ = fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
fmt.Fprint(c, "<ul>") _, _ = fmt.Fprint(c, "<ul>")
if len(basePathEscaped) > 1 { if len(basePathEscaped) > 1 {
parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..") parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..")
fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped) _, _ = fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
} }
sort.Strings(filenames) sort.Strings(filenames)
@ -55,10 +55,10 @@ func dirList(c *fiber.Ctx, f http.File) error {
auxStr = fmt.Sprintf("file, %d bytes", fi.Size()) auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
className = "file" className = "file"
} }
fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`, _, _ = fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime()) pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime())
} }
fmt.Fprint(c, "</ul></body></html>") _, _ = fmt.Fprint(c, "</ul></body></html>")
c.Type("html") c.Type("html")

View File

@ -9,9 +9,7 @@ import (
"github.com/gofiber/fiber/v2/internal/storage/memory" "github.com/gofiber/fiber/v2/internal/storage/memory"
) )
var ( var ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
)
// Config defines the config for middleware. // Config defines the config for middleware.
type Config struct { type Config struct {
@ -51,13 +49,15 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: func(c *fiber.Ctx) bool { Next: func(c *fiber.Ctx) bool {
// Skip middleware if the request was done using a safe HTTP method // Skip middleware if the request was done using a safe HTTP method
return fiber.IsMethodSafe(c.Method()) return fiber.IsMethodSafe(c.Method())
}, },
Lifetime: 30 * time.Minute, Lifetime: 30 * time.Minute, //nolint:gomnd // No magic number, just the default config
KeyHeader: "X-Idempotency-Key", KeyHeader: "X-Idempotency-Key",
KeyHeaderValidate: func(k string) error { KeyHeaderValidate: func(k string) error {
@ -112,7 +112,7 @@ func configDefault(config ...Config) Config {
if cfg.Storage == nil { if cfg.Storage == nil {
cfg.Storage = memory.New(memory.Config{ cfg.Storage = memory.New(memory.Config{
GCInterval: cfg.Lifetime / 2, GCInterval: cfg.Lifetime / 2, //nolint:gomnd // Half the lifetime interval
}) })
} }

View File

@ -1,3 +1,4 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package idempotency_test package idempotency_test
import ( import (
@ -14,6 +15,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/idempotency" "github.com/gofiber/fiber/v2/middleware/idempotency"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -172,5 +174,4 @@ func Benchmark_Idempotency(b *testing.B) {
h(c) h(c)
} }
}) })
} }

View File

@ -1,7 +1,7 @@
package limiter package limiter
import ( import (
"fmt" "log"
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -58,19 +58,21 @@ type Config struct {
// Default: a new Fixed Window Rate Limiter // Default: a new Fixed Window Rate Limiter
LimiterMiddleware LimiterHandler LimiterMiddleware LimiterHandler
// DEPRECATED: Use Expiration instead // Deprecated: Use Expiration instead
Duration time.Duration Duration time.Duration
// DEPRECATED, use Storage instead // Deprecated: Use Storage instead
Store fiber.Storage Store fiber.Storage
// DEPRECATED, use KeyGenerator instead // Deprecated: Use KeyGenerator instead
Key func(*fiber.Ctx) string Key func(*fiber.Ctx) string
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Max: 5, Max: 5, //nolint:gomnd // No magic number, just the default config
Expiration: 1 * time.Minute, Expiration: 1 * time.Minute,
KeyGenerator: func(c *fiber.Ctx) string { KeyGenerator: func(c *fiber.Ctx) string {
return c.IP() return c.IP()
@ -95,15 +97,15 @@ func configDefault(config ...Config) Config {
// Set default values // Set default values
if int(cfg.Duration.Seconds()) > 0 { if int(cfg.Duration.Seconds()) > 0 {
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration") log.Printf("[LIMITER] Duration is deprecated, please use Expiration\n")
cfg.Expiration = cfg.Duration cfg.Expiration = cfg.Duration
} }
if cfg.Key != nil { if cfg.Key != nil {
fmt.Println("[LIMITER] Key is deprecated, please us KeyGenerator") log.Printf("[LIMITER] Key is deprecated, please us KeyGenerator\n")
cfg.KeyGenerator = cfg.Key cfg.KeyGenerator = cfg.Key
} }
if cfg.Store != nil { if cfg.Store != nil {
fmt.Println("[LIMITER] Store is deprecated, please use Storage") log.Printf("[LIMITER] Store is deprecated, please use Storage\n")
cfg.Storage = cfg.Store cfg.Storage = cfg.Store
} }
if cfg.Next == nil { if cfg.Next == nil {

View File

@ -2,7 +2,6 @@ package limiter
import ( import (
"io" "io"
"net/http"
"net/http/httptest" "net/http/httptest"
"sync" "sync"
"testing" "testing"
@ -11,6 +10,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory" "github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -34,7 +34,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) { singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -50,13 +50,13 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
wg.Wait() wg.Wait()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode) utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
} }
@ -80,7 +80,7 @@ func Test_Limiter_Concurrency(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) { singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -96,13 +96,13 @@ func Test_Limiter_Concurrency(t *testing.T) {
wg.Wait() wg.Wait()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode) utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
} }
@ -120,21 +120,21 @@ func Test_Limiter_No_Skip_Choices(t *testing.T) {
})) }))
app.Get("/:status", func(c *fiber.Ctx) error { app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" { if c.Params("status") == "fail" { //nolint:goconst // False positive
return c.SendStatus(400) return c.SendStatus(400)
} }
return c.SendStatus(200) return c.SendStatus(200)
}) })
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode) utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode) utils.AssertEqual(t, 429, resp.StatusCode)
} }
@ -157,21 +157,21 @@ func Test_Limiter_Skip_Failed_Requests(t *testing.T) {
return c.SendStatus(200) return c.SendStatus(200)
}) })
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode) utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode) utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
} }
@ -196,21 +196,21 @@ func Test_Limiter_Skip_Successful_Requests(t *testing.T) {
return c.SendStatus(200) return c.SendStatus(200)
}) })
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode) utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode) utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode) utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode) utils.AssertEqual(t, 400, resp.StatusCode)
} }
@ -232,7 +232,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) {
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/") fctx.Request.SetRequestURI("/")
b.ResetTimer() b.ResetTimer()
@ -252,7 +252,7 @@ func Test_Limiter_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -271,7 +271,7 @@ func Test_Limiter_Headers(t *testing.T) {
}) })
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/") fctx.Request.SetRequestURI("/")
app.Handler()(fctx) app.Handler()(fctx)
@ -301,7 +301,7 @@ func Benchmark_Limiter(b *testing.B) {
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/") fctx.Request.SetRequestURI("/")
b.ResetTimer() b.ResetTimer()
@ -327,7 +327,7 @@ func Test_Sliding_Window(t *testing.T) {
}) })
singleRequest := func(shouldFail bool) { singleRequest := func(shouldFail bool) {
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
if shouldFail { if shouldFail {
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode) utils.AssertEqual(t, 429, resp.StatusCode)

View File

@ -46,7 +46,7 @@ func newManager(storage fiber.Storage) *manager {
// acquire returns an *entry from the sync.Pool // acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item { func (m *manager) acquire() *item {
return m.pool.Get().(*item) return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
} }
// release and reset *entry to sync.Pool // release and reset *entry to sync.Pool
@ -58,37 +58,33 @@ func (m *manager) release(e *item) {
} }
// get data from storage or memory // get data from storage or memory
func (m *manager) get(key string) (it *item) { func (m *manager) get(key string) *item {
var it *item
if m.storage != nil { if m.storage != nil {
it = m.acquire() it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil { raw, err := m.storage.Get(key)
if err != nil {
return it
}
if raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil { if _, err := it.UnmarshalMsg(raw); err != nil {
return return it
} }
} }
return return it
} }
if it, _ = m.memory.Get(key).(*item); it == nil { if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire() it = m.acquire()
return it
} }
return return it
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
} }
// set data to storage or memory // set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) { func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil { if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil { if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp) _ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
} }
// we can release data because it's serialized to database // we can release data because it's serialized to database
m.release(it) m.release(it)
@ -96,21 +92,3 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
m.memory.Set(key, it, exp) m.memory.Set(key, it, exp)
} }
} }
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -189,7 +189,7 @@ const (
TagBytesReceived = "bytesReceived" TagBytesReceived = "bytesReceived"
TagRoute = "route" TagRoute = "route"
TagError = "error" TagError = "error"
// DEPRECATED: Use TagReqHeader instead // Deprecated: Use TagReqHeader instead
TagHeader = "header:" // request header TagHeader = "header:" // request header
TagReqHeader = "reqHeader:" // request header TagReqHeader = "reqHeader:" // request header
TagRespHeader = "respHeader:" // response header TagRespHeader = "respHeader:" // response header

View File

@ -79,13 +79,15 @@ type Buffer interface {
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
Done: nil, Done: nil,
Format: "[${time}] ${status} - ${latency} ${method} ${path}\n", Format: "[${time}] ${status} - ${latency} ${method} ${path}\n",
TimeFormat: "15:04:05", TimeFormat: "15:04:05",
TimeZone: "Local", TimeZone: "Local",
TimeInterval: 500 * time.Millisecond, TimeInterval: 500 * time.Millisecond, //nolint:gomnd // No magic number, just the default config
Output: os.Stdout, Output: os.Stdout,
enableColors: true, enableColors: true,
} }

View File

@ -1,13 +1,10 @@
package logger package logger
import ( import (
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
var DataPool = sync.Pool{New: func() interface{} { return new(Data) }}
// Data is a struct to define some variables to use in custom logger function. // Data is a struct to define some variables to use in custom logger function.
type Data struct { type Data struct {
Pid string Pid string

View File

@ -11,6 +11,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/mattn/go-colorable" "github.com/mattn/go-colorable"
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/valyala/bytebufferpool" "github.com/valyala/bytebufferpool"
@ -55,6 +56,8 @@ func New(config ...Config) fiber.Handler {
once sync.Once once sync.Once
mu sync.Mutex mu sync.Mutex
errHandler fiber.ErrorHandler errHandler fiber.ErrorHandler
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
) )
// If colors are enabled, check terminal compatibility // If colors are enabled, check terminal compatibility
@ -75,7 +78,7 @@ func New(config ...Config) fiber.Handler {
} }
// Return new handler // Return new handler
return func(c *fiber.Ctx) (err error) { return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) { if cfg.Next != nil && cfg.Next(c) {
return c.Next() return c.Next()
@ -101,13 +104,13 @@ func New(config ...Config) fiber.Handler {
}) })
// Logger data // Logger data
data := DataPool.Get().(*Data) data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
// no need for a reset, as long as we always override everything // no need for a reset, as long as we always override everything
data.Pid = pid data.Pid = pid
data.ErrPaddingStr = errPaddingStr data.ErrPaddingStr = errPaddingStr
data.Timestamp = timestamp data.Timestamp = timestamp
// put data back in the pool // put data back in the pool
defer DataPool.Put(data) defer dataPool.Put(data)
// Set latency start time // Set latency start time
if cfg.enableLatency { if cfg.enableLatency {
@ -121,7 +124,7 @@ func New(config ...Config) fiber.Handler {
// Manually call error handler // Manually call error handler
if chainErr != nil { if chainErr != nil {
if err := errHandler(c, chainErr); err != nil { if err := errHandler(c, chainErr); err != nil {
_ = c.SendStatus(fiber.StatusInternalServerError) _ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
} }
} }
@ -142,7 +145,8 @@ func New(config ...Config) fiber.Handler {
} }
// Format log to buffer // Format log to buffer
_, _ = buf.WriteString(fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n", _, _ = buf.WriteString( //nolint:errcheck // This will never fail
fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n",
timestamp.Load().(string), timestamp.Load().(string),
statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset, statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset,
data.Stop.Sub(data.Start).Round(time.Millisecond), data.Stop.Sub(data.Start).Round(time.Millisecond),
@ -150,10 +154,11 @@ func New(config ...Config) fiber.Handler {
methodColor(c.Method(), colors), c.Method(), colors.Reset, methodColor(c.Method(), colors), c.Method(), colors.Reset,
c.Path(), c.Path(),
formatErr, formatErr,
)) ),
)
// Write buffer to output // Write buffer to output
_, _ = cfg.Output.Write(buf.Bytes()) _, _ = cfg.Output.Write(buf.Bytes()) //nolint:errcheck // This will never fail
if cfg.Done != nil { if cfg.Done != nil {
cfg.Done(c, buf.Bytes()) cfg.Done(c, buf.Bytes())
@ -169,7 +174,7 @@ func New(config ...Config) fiber.Handler {
// Loop over template parts execute dynamic parts and add fixed parts to the buffer // Loop over template parts execute dynamic parts and add fixed parts to the buffer
for i, logFunc := range logFunChain { for i, logFunc := range logFunChain {
if logFunc == nil { if logFunc == nil {
_, _ = buf.Write(templateChain[i]) _, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
} else if templateChain[i] == nil { } else if templateChain[i] == nil {
_, err = logFunc(buf, c, data, "") _, err = logFunc(buf, c, data, "")
} else { } else {
@ -182,7 +187,7 @@ func New(config ...Config) fiber.Handler {
// Also write errors to the buffer // Also write errors to the buffer
if err != nil { if err != nil {
_, _ = buf.WriteString(err.Error()) _, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
} }
mu.Lock() mu.Lock()
// Write buffer to output // Write buffer to output
@ -190,7 +195,7 @@ func New(config ...Config) fiber.Handler {
// Write error to output // Write error to output
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil { if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
// There is something wrong with the given io.Writer // There is something wrong with the given io.Writer
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err) _, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
} }
} }
mu.Unlock() mu.Unlock()

View File

@ -1,3 +1,4 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package logger package logger
import ( import (
@ -16,6 +17,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid" "github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool" "github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -37,7 +39,7 @@ func Test_Logger(t *testing.T) {
return errors.New("some random error") return errors.New("some random error")
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
utils.AssertEqual(t, "some random error", buf.String()) utils.AssertEqual(t, "some random error", buf.String())
@ -70,21 +72,21 @@ func Test_Logger_locals(t *testing.T) {
return c.SendStatus(fiber.StatusOK) return c.SendStatus(fiber.StatusOK)
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "johndoe", buf.String()) utils.AssertEqual(t, "johndoe", buf.String())
buf.Reset() buf.Reset()
resp, err = app.Test(httptest.NewRequest("GET", "/int", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/int", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "55", buf.String()) utils.AssertEqual(t, "55", buf.String())
buf.Reset() buf.Reset()
resp, err = app.Test(httptest.NewRequest("GET", "/empty", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "", buf.String()) utils.AssertEqual(t, "", buf.String())
@ -100,7 +102,7 @@ func Test_Logger_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -113,15 +115,15 @@ func Test_Logger_Done(t *testing.T) {
app.Use(New(Config{ app.Use(New(Config{
Done: func(c *fiber.Ctx, logString []byte) { Done: func(c *fiber.Ctx, logString []byte) {
if c.Response().StatusCode() == fiber.StatusOK { if c.Response().StatusCode() == fiber.StatusOK {
buf.Write(logString) _, err := buf.Write(logString)
utils.AssertEqual(t, nil, err)
} }
}, },
})).Get("/logging", func(ctx *fiber.Ctx) error { })).Get("/logging", func(ctx *fiber.Ctx) error {
return ctx.SendStatus(fiber.StatusOK) return ctx.SendStatus(fiber.StatusOK)
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/logging", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, true, buf.Len() > 0) utils.AssertEqual(t, true, buf.Len() > 0)
@ -135,7 +137,7 @@ func Test_Logger_ErrorTimeZone(t *testing.T) {
TimeZone: "invalid", TimeZone: "invalid",
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -156,7 +158,7 @@ func Test_Logger_ErrorOutput(t *testing.T) {
Output: o, Output: o,
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -178,7 +180,7 @@ func Test_Logger_All(t *testing.T) {
// Alias colors // Alias colors
colors := app.Config().ColorScheme colors := app.Config().ColorScheme
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -198,7 +200,7 @@ func Test_Query_Params(t *testing.T) {
Output: buf, Output: buf,
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar&baz=moz", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -226,7 +228,7 @@ func Test_Response_Body(t *testing.T) {
return c.Send([]byte("Post in test")) return c.Send([]byte("Post in test"))
}) })
_, err := app.Test(httptest.NewRequest("GET", "/", nil)) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
expectedGetResponse := "Sample response body" expectedGetResponse := "Sample response body"
@ -234,7 +236,7 @@ func Test_Response_Body(t *testing.T) {
buf.Reset() // Reset buffer to test POST buf.Reset() // Reset buffer to test POST
_, err = app.Test(httptest.NewRequest("POST", "/test", nil)) _, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
expectedPostResponse := "Post in test" expectedPostResponse := "Post in test"
@ -258,7 +260,7 @@ func Test_Logger_AppendUint(t *testing.T) {
return c.SendString("hello") return c.SendString("hello")
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "0 5 200", buf.String()) utils.AssertEqual(t, "0 5 200", buf.String())
@ -285,12 +287,11 @@ func Test_Logger_Data_Race(t *testing.T) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
resp1, err1 = app.Test(httptest.NewRequest("GET", "/", nil)) resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
wg.Done() wg.Done()
}() }()
resp2, err2 = app.Test(httptest.NewRequest("GET", "/", nil)) resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
wg.Wait() wg.Wait()
utils.AssertEqual(t, nil, err1) utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode)
utils.AssertEqual(t, nil, err2) utils.AssertEqual(t, nil, err2)
@ -299,21 +300,23 @@ func Test_Logger_Data_Race(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4 // go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4
func Benchmark_Logger(b *testing.B) { func Benchmark_Logger(b *testing.B) {
benchSetup := func(bb *testing.B, app *fiber.App) { benchSetup := func(b *testing.B, app *fiber.App) {
b.Helper()
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/") fctx.Request.SetRequestURI("/")
bb.ReportAllocs() b.ReportAllocs()
bb.ResetTimer() b.ResetTimer()
for n := 0; n < bb.N; n++ { for n := 0; n < b.N; n++ {
h(fctx) h(fctx)
} }
utils.AssertEqual(bb, 200, fctx.Response.Header.StatusCode()) utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
} }
b.Run("Base", func(bb *testing.B) { b.Run("Base", func(bb *testing.B) {
@ -375,8 +378,7 @@ func Test_Response_Header(t *testing.T) {
return c.SendString("Hello fiber!") return c.SendString("Hello fiber!")
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String()) utils.AssertEqual(t, "Hello fiber!", buf.String())
@ -396,10 +398,10 @@ func Test_Req_Header(t *testing.T) {
app.Get("/", func(c *fiber.Ctx) error { app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!") return c.SendString("Hello fiber!")
}) })
headerReq := httptest.NewRequest("GET", "/", nil) headerReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
headerReq.Header.Add("test", "Hello fiber!") headerReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(headerReq)
resp, err := app.Test(headerReq)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String()) utils.AssertEqual(t, "Hello fiber!", buf.String())
@ -419,10 +421,10 @@ func Test_ReqHeader_Header(t *testing.T) {
app.Get("/", func(c *fiber.Ctx) error { app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!") return c.SendString("Hello fiber!")
}) })
reqHeaderReq := httptest.NewRequest("GET", "/", nil) reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
reqHeaderReq.Header.Add("test", "Hello fiber!") reqHeaderReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(reqHeaderReq)
resp, err := app.Test(reqHeaderReq)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String()) utils.AssertEqual(t, "Hello fiber!", buf.String())
@ -449,10 +451,10 @@ func Test_CustomTags(t *testing.T) {
app.Get("/", func(c *fiber.Ctx) error { app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!") return c.SendString("Hello fiber!")
}) })
reqHeaderReq := httptest.NewRequest("GET", "/", nil) reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
reqHeaderReq.Header.Add("test", "Hello fiber!") reqHeaderReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(reqHeaderReq)
resp, err := app.Test(reqHeaderReq)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, customTag, buf.String()) utils.AssertEqual(t, customTag, buf.String())
@ -492,7 +494,7 @@ func Test_Logger_ByteSent_Streaming(t *testing.T) {
return nil return nil
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "0 0 200", buf.String()) utils.AssertEqual(t, "0 0 200", buf.String())

View File

@ -31,7 +31,7 @@ const (
TagBytesReceived = "bytesReceived" TagBytesReceived = "bytesReceived"
TagRoute = "route" TagRoute = "route"
TagError = "error" TagError = "error"
// DEPRECATED: Use TagReqHeader instead // Deprecated: Use TagReqHeader instead
TagHeader = "header:" TagHeader = "header:"
TagReqHeader = "reqHeader:" TagReqHeader = "reqHeader:"
TagRespHeader = "respHeader:" TagRespHeader = "respHeader:"
@ -195,7 +195,7 @@ func createTagMap(cfg *Config) map[string]LogFunc {
return output.WriteString(fmt.Sprintf("%7v", latency)) return output.WriteString(fmt.Sprintf("%7v", latency))
}, },
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) { TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(data.Timestamp.Load().(string)) return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
}, },
} }
// merge with custom tags from user // merge with custom tags from user

View File

@ -14,13 +14,16 @@ import (
// funcChain contains for the parts which exist the functions for the dynamic parts // funcChain contains for the parts which exist the functions for the dynamic parts
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain, // funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
// if a function exists for the part, a parameter for it can also exist in the fixParts slice // if a function exists for the part, a parameter for it can also exist in the fixParts slice
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [][]byte, funcChain []LogFunc, err error) { func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62 // process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
templateB := utils.UnsafeBytes(cfg.Format) templateB := utils.UnsafeBytes(cfg.Format)
startTagB := utils.UnsafeBytes(startTag) startTagB := utils.UnsafeBytes(startTag)
endTagB := utils.UnsafeBytes(endTag) endTagB := utils.UnsafeBytes(endTag)
paramSeparatorB := utils.UnsafeBytes(paramSeparator) paramSeparatorB := utils.UnsafeBytes(paramSeparator)
var fixParts [][]byte
var funcChain []LogFunc
for { for {
currentPos := bytes.Index(templateB, startTagB) currentPos := bytes.Index(templateB, startTagB)
if currentPos < 0 { if currentPos < 0 {
@ -42,13 +45,13 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
// ## function block ## // ## function block ##
// first check for tags with parameters // first check for tags with parameters
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 { if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]; ok { logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
if !ok {
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
}
funcChain = append(funcChain, logFunc) funcChain = append(funcChain, logFunc)
// add param to the fixParts // add param to the fixParts
fixParts = append(fixParts, templateB[index+1:currentPos]) fixParts = append(fixParts, templateB[index+1:currentPos])
} else {
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
}
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok { } else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
// add functions without parameter // add functions without parameter
funcChain = append(funcChain, logFunc) funcChain = append(funcChain, logFunc)
@ -63,5 +66,5 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
funcChain = append(funcChain, nil) funcChain = append(funcChain, nil)
fixParts = append(fixParts, templateB) fixParts = append(fixParts, templateB)
return return fixParts, funcChain, nil
} }

View File

@ -41,36 +41,48 @@ type Config struct {
// ChartJsURL for specify ChartJS library path or URL . also you can use relative path // ChartJsURL for specify ChartJS library path or URL . also you can use relative path
// //
// Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js // Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js
ChartJsURL string ChartJSURL string
index string index string
} }
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Title: defaultTitle, Title: defaultTitle,
Refresh: defaultRefresh, Refresh: defaultRefresh,
FontURL: defaultFontURL, FontURL: defaultFontURL,
ChartJsURL: defaultChartJsURL, ChartJSURL: defaultChartJSURL,
CustomHead: defaultCustomHead, CustomHead: defaultCustomHead,
APIOnly: false, APIOnly: false,
Next: nil, Next: nil,
index: newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, index: newIndex(viewBag{
defaultCustomHead}), defaultTitle,
defaultRefresh,
defaultFontURL,
defaultChartJSURL,
defaultCustomHead,
}),
} }
func configDefault(config ...Config) Config { func configDefault(config ...Config) Config {
// Users can change ConfigDefault.Title/Refresh which then // Users can change ConfigDefault.Title/Refresh which then
// become incompatible with ConfigDefault.index // become incompatible with ConfigDefault.index
if ConfigDefault.Title != defaultTitle || ConfigDefault.Refresh != defaultRefresh || if ConfigDefault.Title != defaultTitle ||
ConfigDefault.FontURL != defaultFontURL || ConfigDefault.ChartJsURL != defaultChartJsURL || ConfigDefault.Refresh != defaultRefresh ||
ConfigDefault.FontURL != defaultFontURL ||
ConfigDefault.ChartJSURL != defaultChartJSURL ||
ConfigDefault.CustomHead != defaultCustomHead { ConfigDefault.CustomHead != defaultCustomHead {
if ConfigDefault.Refresh < minRefresh { if ConfigDefault.Refresh < minRefresh {
ConfigDefault.Refresh = minRefresh ConfigDefault.Refresh = minRefresh
} }
// update default index with new default title/refresh // update default index with new default title/refresh
ConfigDefault.index = newIndex(viewBag{ConfigDefault.Title, ConfigDefault.index = newIndex(viewBag{
ConfigDefault.Refresh, ConfigDefault.FontURL, ConfigDefault.ChartJsURL, ConfigDefault.CustomHead}) ConfigDefault.Title,
ConfigDefault.Refresh,
ConfigDefault.FontURL,
ConfigDefault.ChartJSURL,
ConfigDefault.CustomHead,
})
} }
// Return default config if nothing provided // Return default config if nothing provided
@ -93,8 +105,8 @@ func configDefault(config ...Config) Config {
cfg.FontURL = defaultFontURL cfg.FontURL = defaultFontURL
} }
if cfg.ChartJsURL == "" { if cfg.ChartJSURL == "" {
cfg.ChartJsURL = defaultChartJsURL cfg.ChartJSURL = defaultChartJSURL
} }
if cfg.Refresh < minRefresh { if cfg.Refresh < minRefresh {
cfg.Refresh = minRefresh cfg.Refresh = minRefresh
@ -112,8 +124,8 @@ func configDefault(config ...Config) Config {
cfg.index = newIndex(viewBag{ cfg.index = newIndex(viewBag{
title: cfg.Title, title: cfg.Title,
refresh: cfg.Refresh, refresh: cfg.Refresh,
fontUrl: cfg.FontURL, fontURL: cfg.FontURL,
chartJsUrl: cfg.ChartJsURL, chartJSURL: cfg.ChartJSURL,
customHead: cfg.CustomHead, customHead: cfg.CustomHead,
}) })

View File

@ -18,11 +18,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh) utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
}) })
t.Run("set title", func(t *testing.T) { t.Run("set title", func(t *testing.T) {
@ -35,11 +35,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, title, cfg.Title) utils.AssertEqual(t, title, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh) utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
}) })
t.Run("set refresh less than default", func(t *testing.T) { t.Run("set refresh less than default", func(t *testing.T) {
@ -51,11 +51,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, minRefresh, cfg.Refresh) utils.AssertEqual(t, minRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
}) })
t.Run("set refresh", func(t *testing.T) { t.Run("set refresh", func(t *testing.T) {
@ -68,45 +68,45 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, refresh, cfg.Refresh) utils.AssertEqual(t, refresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
}) })
t.Run("set font url", func(t *testing.T) { t.Run("set font url", func(t *testing.T) {
t.Parallel() t.Parallel()
fontUrl := "https://example.com" fontURL := "https://example.com"
cfg := configDefault(Config{ cfg := configDefault(Config{
FontURL: fontUrl, FontURL: fontURL,
}) })
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh) utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, fontUrl, cfg.FontURL) utils.AssertEqual(t, fontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontUrl, defaultChartJsURL, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
}) })
t.Run("set chart js url", func(t *testing.T) { t.Run("set chart js url", func(t *testing.T) {
t.Parallel() t.Parallel()
chartUrl := "http://example.com" chartURL := "http://example.com"
cfg := configDefault(Config{ cfg := configDefault(Config{
ChartJsURL: chartUrl, ChartJSURL: chartURL,
}) })
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh) utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, chartUrl, cfg.ChartJsURL) utils.AssertEqual(t, chartURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartUrl, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartURL, defaultCustomHead}), cfg.index)
}) })
t.Run("set custom head", func(t *testing.T) { t.Run("set custom head", func(t *testing.T) {
@ -119,11 +119,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh) utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, head, cfg.CustomHead) utils.AssertEqual(t, head, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, head}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, head}), cfg.index)
}) })
t.Run("set api only", func(t *testing.T) { t.Run("set api only", func(t *testing.T) {
@ -135,11 +135,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh) utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, true, cfg.APIOnly) utils.AssertEqual(t, true, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next) utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
}) })
t.Run("set next", func(t *testing.T) { t.Run("set next", func(t *testing.T) {
@ -154,10 +154,10 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title) utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh) utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL) utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL) utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead) utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly) utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, f(nil), cfg.Next(nil)) utils.AssertEqual(t, f(nil), cfg.Next(nil))
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index) utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
}) })
} }

View File

@ -9,23 +9,22 @@ import (
type viewBag struct { type viewBag struct {
title string title string
refresh time.Duration refresh time.Duration
fontUrl string fontURL string
chartJsUrl string chartJSURL string
customHead string customHead string
} }
// returns index with new title/refresh // returns index with new title/refresh
func newIndex(dat viewBag) string { func newIndex(dat viewBag) string {
timeout := dat.refresh.Milliseconds() - timeoutDiff timeout := dat.refresh.Milliseconds() - timeoutDiff
if timeout < timeoutDiff { if timeout < timeoutDiff {
timeout = timeoutDiff timeout = timeoutDiff
} }
ts := strconv.FormatInt(timeout, 10) ts := strconv.FormatInt(timeout, 10)
replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts, replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts,
"$FONT_URL", dat.fontUrl, "$CHART_JS_URL", dat.chartJsUrl, "$CUSTOM_HEAD", dat.customHead, "$FONT_URL", dat.fontURL, "$CHART_JS_URL", dat.chartJSURL, "$CUSTOM_HEAD", dat.customHead,
) )
return replacer.Replace(indexHtml) return replacer.Replace(indexHTML)
} }
const ( const (
@ -35,11 +34,11 @@ const (
timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff
minRefresh = timeoutDiff * time.Millisecond minRefresh = timeoutDiff * time.Millisecond
defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap` defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap`
defaultChartJsURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js` defaultChartJSURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
defaultCustomHead = `` defaultCustomHead = ``
// parametrized by $TITLE and $TIMEOUT // parametrized by $TITLE and $TIMEOUT
indexHtml = `<!DOCTYPE html> indexHTML = `<!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">

View File

@ -33,18 +33,20 @@ type statsOS struct {
Conns int `json:"conns"` Conns int `json:"conns"`
} }
//nolint:gochecknoglobals // TODO: Do not use a global var here
var ( var (
monitPidCpu atomic.Value monitPIDCPU atomic.Value
monitPidRam atomic.Value monitPIDRAM atomic.Value
monitPidConns atomic.Value monitPIDConns atomic.Value
monitOsCpu atomic.Value monitOSCPU atomic.Value
monitOsRam atomic.Value monitOSRAM atomic.Value
monitOsTotalRam atomic.Value monitOSTotalRAM atomic.Value
monitOsLoadAvg atomic.Value monitOSLoadAvg atomic.Value
monitOsConns atomic.Value monitOSConns atomic.Value
) )
//nolint:gochecknoglobals // TODO: Do not use a global var here
var ( var (
mutex sync.RWMutex mutex sync.RWMutex
once sync.Once once sync.Once
@ -58,7 +60,7 @@ func New(config ...Config) fiber.Handler {
// Start routine to update statistics // Start routine to update statistics
once.Do(func() { once.Do(func() {
p, _ := process.NewProcess(int32(os.Getpid())) p, _ := process.NewProcess(int32(os.Getpid())) //nolint:errcheck // TODO: Handle error
updateStatistics(p) updateStatistics(p)
@ -72,6 +74,7 @@ func New(config ...Config) fiber.Handler {
}) })
// Return new handler // Return new handler
//nolint:errcheck // Ignore the type-assertion errors
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) { if cfg.Next != nil && cfg.Next(c) {
@ -83,15 +86,15 @@ func New(config ...Config) fiber.Handler {
} }
if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly { if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly {
mutex.Lock() mutex.Lock()
data.PID.CPU = monitPidCpu.Load().(float64) data.PID.CPU, _ = monitPIDCPU.Load().(float64)
data.PID.RAM = monitPidRam.Load().(uint64) data.PID.RAM, _ = monitPIDRAM.Load().(uint64)
data.PID.Conns = monitPidConns.Load().(int) data.PID.Conns, _ = monitPIDConns.Load().(int)
data.OS.CPU = monitOsCpu.Load().(float64) data.OS.CPU, _ = monitOSCPU.Load().(float64)
data.OS.RAM = monitOsRam.Load().(uint64) data.OS.RAM, _ = monitOSRAM.Load().(uint64)
data.OS.TotalRAM = monitOsTotalRam.Load().(uint64) data.OS.TotalRAM, _ = monitOSTotalRAM.Load().(uint64)
data.OS.LoadAvg = monitOsLoadAvg.Load().(float64) data.OS.LoadAvg, _ = monitOSLoadAvg.Load().(float64)
data.OS.Conns = monitOsConns.Load().(int) data.OS.Conns, _ = monitOSConns.Load().(int)
mutex.Unlock() mutex.Unlock()
return c.Status(fiber.StatusOK).JSON(data) return c.Status(fiber.StatusOK).JSON(data)
} }
@ -101,29 +104,35 @@ func New(config ...Config) fiber.Handler {
} }
func updateStatistics(p *process.Process) { func updateStatistics(p *process.Process) {
pidCpu, _ := p.CPUPercent() pidCPU, err := p.CPUPercent()
monitPidCpu.Store(pidCpu / 10) if err != nil {
monitPIDCPU.Store(pidCPU / 10) //nolint:gomnd // TODO: Explain why we divide by 10 here
if osCpu, _ := cpu.Percent(0, false); len(osCpu) > 0 {
monitOsCpu.Store(osCpu[0])
} }
if pidMem, _ := p.MemoryInfo(); pidMem != nil { if osCPU, err := cpu.Percent(0, false); err != nil && len(osCPU) > 0 {
monitPidRam.Store(pidMem.RSS) monitOSCPU.Store(osCPU[0])
} }
if osMem, _ := mem.VirtualMemory(); osMem != nil { if pidRAM, err := p.MemoryInfo(); err != nil && pidRAM != nil {
monitOsRam.Store(osMem.Used) monitPIDRAM.Store(pidRAM.RSS)
monitOsTotalRam.Store(osMem.Total)
} }
if loadAvg, _ := load.Avg(); loadAvg != nil { if osRAM, err := mem.VirtualMemory(); err != nil && osRAM != nil {
monitOsLoadAvg.Store(loadAvg.Load1) monitOSRAM.Store(osRAM.Used)
monitOSTotalRAM.Store(osRAM.Total)
} }
pidConns, _ := net.ConnectionsPid("tcp", p.Pid) if loadAvg, err := load.Avg(); err != nil && loadAvg != nil {
monitPidConns.Store(len(pidConns)) monitOSLoadAvg.Store(loadAvg.Load1)
}
osConns, _ := net.Connections("tcp")
monitOsConns.Store(len(osConns)) pidConns, err := net.ConnectionsPid("tcp", p.Pid)
if err != nil {
monitPIDConns.Store(len(pidConns))
}
osConns, err := net.Connections("tcp")
if err != nil {
monitOSConns.Store(len(osConns))
}
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -61,6 +62,7 @@ func Test_Monitor_Html(t *testing.T) {
conf.Refresh.Milliseconds()-timeoutDiff) conf.Refresh.Milliseconds()-timeoutDiff)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine))) utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
} }
func Test_Monitor_Html_CustomCodes(t *testing.T) { func Test_Monitor_Html_CustomCodes(t *testing.T) {
t.Parallel() t.Parallel()
@ -82,8 +84,10 @@ func Test_Monitor_Html_CustomCodes(t *testing.T) {
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine))) utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
// custom config // custom config
conf := Config{Title: "New " + defaultTitle, Refresh: defaultRefresh + time.Second, conf := Config{
ChartJsURL: "https://cdnjs.com/libraries/Chart.js", Title: "New " + defaultTitle,
Refresh: defaultRefresh + time.Second,
ChartJSURL: "https://cdnjs.com/libraries/Chart.js",
FontURL: "/public/my-font.css", FontURL: "/public/my-font.css",
CustomHead: `<style>body{background:#fff}</style>`, CustomHead: `<style>body{background:#fff}</style>`,
} }
@ -136,7 +140,7 @@ func Benchmark_Monitor(b *testing.B) {
h := app.Handler() h := app.Handler()
fctx := &fasthttp.RequestCtx{} fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET") fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/") fctx.Request.SetRequestURI("/")
fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON) fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)

View File

@ -1,6 +1,8 @@
package pprof package pprof
import "github.com/gofiber/fiber/v2" import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware. // Config defines the config for middleware.
type Config struct { type Config struct {
@ -17,6 +19,7 @@ type Config struct {
Prefix string Prefix string
} }
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
} }

View File

@ -5,9 +5,15 @@ import (
"strings" "strings"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp/fasthttpadaptor" "github.com/valyala/fasthttp/fasthttpadaptor"
) )
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Set pprof adaptors // Set pprof adaptors
var ( var (
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index) pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
@ -23,11 +29,6 @@ var (
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP) pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
) )
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler // Return new handler
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true

View File

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -48,7 +49,7 @@ type Config struct {
WriteBufferSize int WriteBufferSize int
// tls config for the http client. // tls config for the http client.
TlsConfig *tls.Config TlsConfig *tls.Config //nolint:stylecheck,revive // TODO: Rename to "TLSConfig" in v3
// Client is custom client when client config is complex. // Client is custom client when client config is complex.
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
@ -57,6 +58,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
ModifyRequest: nil, ModifyRequest: nil,

View File

@ -3,19 +3,20 @@ package proxy
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"fmt" "log"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
// New is deprecated // New is deprecated
func New(config Config) fiber.Handler { func New(config Config) fiber.Handler {
fmt.Println("proxy.New is deprecated, please use proxy.Balancer instead") log.Printf("proxy.New is deprecated, please use proxy.Balancer instead\n")
return Balancer(config) return Balancer(config)
} }
@ -25,7 +26,7 @@ func Balancer(config Config) fiber.Handler {
cfg := configDefault(config) cfg := configDefault(config)
// Load balanced client // Load balanced client
var lbc = &fasthttp.LBClient{} lbc := &fasthttp.LBClient{}
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set. // will not be used if the client are set.
if config.Client == nil { if config.Client == nil {
@ -61,7 +62,7 @@ func Balancer(config Config) fiber.Handler {
} }
// Return new handler // Return new handler
return func(c *fiber.Ctx) (err error) { return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) { if cfg.Next != nil && cfg.Next(c) {
return c.Next() return c.Next()
@ -76,7 +77,7 @@ func Balancer(config Config) fiber.Handler {
// Modify request // Modify request
if cfg.ModifyRequest != nil { if cfg.ModifyRequest != nil {
if err = cfg.ModifyRequest(c); err != nil { if err := cfg.ModifyRequest(c); err != nil {
return err return err
} }
} }
@ -84,7 +85,7 @@ func Balancer(config Config) fiber.Handler {
req.SetRequestURI(utils.UnsafeString(req.RequestURI())) req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
// Forward request // Forward request
if err = lbc.Do(req, res); err != nil { if err := lbc.Do(req, res); err != nil {
return err return err
} }
@ -93,7 +94,7 @@ func Balancer(config Config) fiber.Handler {
// Modify response // Modify response
if cfg.ModifyResponse != nil { if cfg.ModifyResponse != nil {
if err = cfg.ModifyResponse(c); err != nil { if err := cfg.ModifyResponse(c); err != nil {
return err return err
} }
} }
@ -103,16 +104,20 @@ func Balancer(config Config) fiber.Handler {
} }
} }
//nolint:gochecknoglobals // TODO: Do not use a global var here
var client = &fasthttp.Client{ var client = &fasthttp.Client{
NoDefaultUserAgentHeader: true, NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true, DisablePathNormalizing: true,
} }
//nolint:gochecknoglobals // TODO: Do not use a global var here
var lock sync.RWMutex var lock sync.RWMutex
// WithTlsConfig update http client with a user specified tls.config // WithTlsConfig update http client with a user specified tls.config
// This function should be called before Do and Forward. // This function should be called before Do and Forward.
// Deprecated: use WithClient instead. // Deprecated: use WithClient instead.
//
//nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3
func WithTlsConfig(tlsConfig *tls.Config) { func WithTlsConfig(tlsConfig *tls.Config) {
client.TLSConfig = tlsConfig client.TLSConfig = tlsConfig
} }

View File

@ -4,7 +4,6 @@ import (
"crypto/tls" "crypto/tls"
"io" "io"
"net" "net"
"net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
@ -13,10 +12,11 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/tlstest" "github.com/gofiber/fiber/v2/internal/tlstest"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) { func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, string) {
t.Helper() t.Helper()
target := fiber.New(fiber.Config{DisableStartupMessage: true}) target := fiber.New(fiber.Config{DisableStartupMessage: true})
@ -60,7 +60,7 @@ func Test_Proxy_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -69,11 +69,11 @@ func Test_Proxy_Next(t *testing.T) {
func Test_Proxy(t *testing.T) { func Test_Proxy(t *testing.T) {
t.Parallel() t.Parallel()
target, addr := createProxyTestServer( target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t, return c.SendStatus(fiber.StatusTeapot)
) })
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
@ -81,7 +81,7 @@ func Test_Proxy(t *testing.T) {
app.Use(Balancer(Config{Servers: []string{addr}})) app.Use(Balancer(Config{Servers: []string{addr}}))
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Host = addr req.Host = addr
resp, err = app.Test(req) resp, err = app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -107,7 +107,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
}) })
addr := ln.Addr().String() addr := ln.Addr().String()
clientTLSConf := &tls.Config{InsecureSkipVerify: true} clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification in Balancer // disable certificate verification in Balancer
app.Use(Balancer(Config{ app.Use(Balancer(Config{
@ -128,9 +128,9 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
t.Parallel() t.Parallel()
_, targetAddr := createProxyTestServer(func(c *fiber.Ctx) error { _, targetAddr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("hello from target") return c.SendString("hello from target")
}, t) })
proxyServerTLSConf, _, err := tlstest.GetTLSConfigs() proxyServerTLSConf, _, err := tlstest.GetTLSConfigs()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -164,13 +164,13 @@ func Test_Proxy_Forward(t *testing.T) {
app := fiber.New() app := fiber.New()
_, addr := createProxyTestServer( _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
func(c *fiber.Ctx) error { return c.SendString("forwarded") }, t, return c.SendString("forwarded")
) })
app.Use(Forward("http://" + addr)) app.Use(Forward("http://" + addr))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -198,7 +198,7 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
}) })
addr := ln.Addr().String() addr := ln.Addr().String()
clientTLSConf := &tls.Config{InsecureSkipVerify: true} clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification // disable certificate verification
WithTlsConfig(clientTLSConf) WithTlsConfig(clientTLSConf)
@ -217,9 +217,9 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
func Test_Proxy_Modify_Response(t *testing.T) { func Test_Proxy_Modify_Response(t *testing.T) {
t.Parallel() t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error { _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.Status(500).SendString("not modified") return c.Status(500).SendString("not modified")
}, t) })
app := fiber.New() app := fiber.New()
app.Use(Balancer(Config{ app.Use(Balancer(Config{
@ -230,7 +230,7 @@ func Test_Proxy_Modify_Response(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -243,10 +243,10 @@ func Test_Proxy_Modify_Response(t *testing.T) {
func Test_Proxy_Modify_Request(t *testing.T) { func Test_Proxy_Modify_Request(t *testing.T) {
t.Parallel() t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error { _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
b := c.Request().Body() b := c.Request().Body()
return c.SendString(string(b)) return c.SendString(string(b))
}, t) })
app := fiber.New() app := fiber.New()
app.Use(Balancer(Config{ app.Use(Balancer(Config{
@ -257,7 +257,7 @@ func Test_Proxy_Modify_Request(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -270,10 +270,10 @@ func Test_Proxy_Modify_Request(t *testing.T) {
func Test_Proxy_Timeout_Slow_Server(t *testing.T) { func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
t.Parallel() t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error { _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return c.SendString("fiber is awesome") return c.SendString("fiber is awesome")
}, t) })
app := fiber.New() app := fiber.New()
app.Use(Balancer(Config{ app.Use(Balancer(Config{
@ -281,7 +281,7 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
Timeout: 3 * time.Second, Timeout: 3 * time.Second,
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 5000) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 5000)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -294,10 +294,10 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
func Test_Proxy_With_Timeout(t *testing.T) { func Test_Proxy_With_Timeout(t *testing.T) {
t.Parallel() t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error { _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return c.SendString("fiber is awesome") return c.SendString("fiber is awesome")
}, t) })
app := fiber.New() app := fiber.New()
app.Use(Balancer(Config{ app.Use(Balancer(Config{
@ -305,7 +305,7 @@ func Test_Proxy_With_Timeout(t *testing.T) {
Timeout: 100 * time.Millisecond, Timeout: 100 * time.Millisecond,
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 2000) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
@ -318,16 +318,16 @@ func Test_Proxy_With_Timeout(t *testing.T) {
func Test_Proxy_Buffer_Size_Response(t *testing.T) { func Test_Proxy_Buffer_Size_Response(t *testing.T) {
t.Parallel() t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error { _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
long := strings.Join(make([]string, 5000), "-") long := strings.Join(make([]string, 5000), "-")
c.Set("Very-Long-Header", long) c.Set("Very-Long-Header", long)
return c.SendString("ok") return c.SendString("ok")
}, t) })
app := fiber.New() app := fiber.New()
app.Use(Balancer(Config{Servers: []string{addr}})) app.Use(Balancer(Config{Servers: []string{addr}}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
@ -337,7 +337,7 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) {
ReadBufferSize: 1024 * 8, ReadBufferSize: 1024 * 8,
})) }))
resp, err = app.Test(httptest.NewRequest("GET", "/", nil)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
} }
@ -357,9 +357,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
utils.AssertEqual(t, originalURL, c.OriginalURL()) utils.AssertEqual(t, originalURL, c.OriginalURL())
return c.SendString("ok") return c.SendString("ok")
}) })
_, err1 := app.Test(httptest.NewRequest("GET", "/test", nil)) _, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
// This test requires multiple requests due to zero allocation used in fiber // This test requires multiple requests due to zero allocation used in fiber
_, err2 := app.Test(httptest.NewRequest("GET", "/test", nil)) _, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1) utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, nil, err2) utils.AssertEqual(t, nil, err2)
@ -369,9 +369,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) { func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
t.Parallel() t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error { _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("hello world") return c.SendString("hello world")
}, t) })
app := fiber.New(fiber.Config{DisableStartupMessage: true}) app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/*", func(c *fiber.Ctx) error { app.Get("/*", func(c *fiber.Ctx) error {
@ -386,7 +386,7 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
return nil return nil
}) })
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/http://"+addr, nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
s, err := io.ReadAll(resp.Body) s, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
@ -431,9 +431,8 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{ app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
NoDefaultUserAgentHeader: true, NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true, DisablePathNormalizing: true,
Dial: func(addr string) (net.Conn, error) {
return fasthttp.Dial(addr) Dial: fasthttp.Dial,
},
})) }))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
@ -447,11 +446,11 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
func Test_ProxyBalancer_Custom_Client(t *testing.T) { func Test_ProxyBalancer_Custom_Client(t *testing.T) {
t.Parallel() t.Parallel()
target, addr := createProxyTestServer( target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t, return c.SendStatus(fiber.StatusTeapot)
) })
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
@ -468,7 +467,7 @@ func Test_ProxyBalancer_Custom_Client(t *testing.T) {
Timeout: time.Second, Timeout: time.Second,
}})) }}))
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Host = addr req.Host = addr
resp, err = app.Test(req) resp, err = app.Test(req)
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)

View File

@ -1,4 +1,4 @@
package recover package recover //nolint:predeclared // TODO: Rename to some non-builtin
import ( import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -23,6 +23,8 @@ type Config struct {
} }
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
EnableStackTrace: false, EnableStackTrace: false,

View File

@ -1,4 +1,4 @@
package recover package recover //nolint:predeclared // TODO: Rename to some non-builtin
import ( import (
"fmt" "fmt"
@ -9,7 +9,7 @@ import (
) )
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) { func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) _, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
} }
// New creates a new middleware handler // New creates a new middleware handler
@ -18,7 +18,7 @@ func New(config ...Config) fiber.Handler {
cfg := configDefault(config...) cfg := configDefault(config...)
// Return new handler // Return new handler
return func(c *fiber.Ctx) (err error) { return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
// Don't execute middleware if Next returns true // Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) { if cfg.Next != nil && cfg.Next(c) {
return c.Next() return c.Next()

View File

@ -1,4 +1,4 @@
package recover package recover //nolint:predeclared // TODO: Rename to some non-builtin
import ( import (
"net/http/httptest" "net/http/httptest"
@ -24,7 +24,7 @@ func Test_Recover(t *testing.T) {
panic("Hi, I'm an error!") panic("Hi, I'm an error!")
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
} }
@ -39,7 +39,7 @@ func Test_Recover_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
} }
@ -55,7 +55,7 @@ func Test_Recover_EnableStackTrace(t *testing.T) {
panic("Hi, I'm an error!") panic("Hi, I'm an error!")
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
} }

View File

@ -33,6 +33,8 @@ type Config struct {
// It uses a fast UUID generator which will expose the number of // It uses a fast UUID generator which will expose the number of
// requests made to the server. To conceal this value for better // requests made to the server. To conceal this value for better
// privacy, use the "utils.UUIDv4" generator. // privacy, use the "utils.UUIDv4" generator.
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Next: nil, Next: nil,
Header: fiber.HeaderXRequestID, Header: fiber.HeaderXRequestID,

View File

@ -19,14 +19,14 @@ func Test_RequestID(t *testing.T) {
return c.SendString("Hello, World 👋!") return c.SendString("Hello, World 👋!")
}) })
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
reqid := resp.Header.Get(fiber.HeaderXRequestID) reqid := resp.Header.Get(fiber.HeaderXRequestID)
utils.AssertEqual(t, 36, len(reqid)) utils.AssertEqual(t, 36, len(reqid))
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add(fiber.HeaderXRequestID, reqid) req.Header.Add(fiber.HeaderXRequestID, reqid)
resp, err = app.Test(req) resp, err = app.Test(req)
@ -45,7 +45,7 @@ func Test_RequestID_Next(t *testing.T) {
}, },
})) }))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "") utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "")
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -54,13 +54,13 @@ func Test_RequestID_Next(t *testing.T) {
// go test -run Test_RequestID_Locals // go test -run Test_RequestID_Locals
func Test_RequestID_Locals(t *testing.T) { func Test_RequestID_Locals(t *testing.T) {
t.Parallel() t.Parallel()
reqId := "ThisIsARequestId" reqID := "ThisIsARequestId"
ctxKey := "ThisIsAContextKey" ctxKey := "ThisIsAContextKey"
app := fiber.New() app := fiber.New()
app.Use(New(Config{ app.Use(New(Config{
Generator: func() string { Generator: func() string {
return reqId return reqID
}, },
ContextKey: ctxKey, ContextKey: ctxKey,
})) }))
@ -68,11 +68,11 @@ func Test_RequestID_Locals(t *testing.T) {
var ctxVal string var ctxVal string
app.Use(func(c *fiber.Ctx) error { app.Use(func(c *fiber.Ctx) error {
ctxVal = c.Locals(ctxKey).(string) ctxVal = c.Locals(ctxKey).(string) //nolint:forcetypeassert,errcheck // We always store a string in here
return c.Next() return c.Next()
}) })
_, err := app.Test(httptest.NewRequest("GET", "/", nil)) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, reqId, ctxVal) utils.AssertEqual(t, reqID, ctxVal)
} }

View File

@ -148,7 +148,7 @@ type Config struct {
// Optional. Default value utils.UUID // Optional. Default value utils.UUID
KeyGenerator func() string KeyGenerator func() string
// Deprecated, please use KeyLookup // Deprecated: Please use KeyLookup
CookieName string CookieName string
// Source defines where to obtain the session id // Source defines where to obtain the session id

View File

@ -2,6 +2,7 @@ package session
import ( import (
"fmt" "fmt"
"log"
"strings" "strings"
"time" "time"
@ -49,7 +50,7 @@ type Config struct {
// Optional. Default value utils.UUIDv4 // Optional. Default value utils.UUIDv4
KeyGenerator func() string KeyGenerator func() string
// Deprecated, please use KeyLookup // Deprecated: Please use KeyLookup
CookieName string CookieName string
// Source defines where to obtain the session id // Source defines where to obtain the session id
@ -68,8 +69,10 @@ const (
) )
// ConfigDefault is the default config // ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{ var ConfigDefault = Config{
Expiration: 24 * time.Hour, Expiration: 24 * time.Hour, //nolint:gomnd // No magic number, just the default config
KeyLookup: "cookie:session_id", KeyLookup: "cookie:session_id",
KeyGenerator: utils.UUIDv4, KeyGenerator: utils.UUIDv4,
source: "cookie", source: "cookie",
@ -91,7 +94,7 @@ func configDefault(config ...Config) Config {
cfg.Expiration = ConfigDefault.Expiration cfg.Expiration = ConfigDefault.Expiration
} }
if cfg.CookieName != "" { if cfg.CookieName != "" {
fmt.Println("[session] CookieName is deprecated, please use KeyLookup") log.Printf("[session] CookieName is deprecated, please use KeyLookup\n")
cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName) cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName)
} }
if cfg.KeyLookup == "" { if cfg.KeyLookup == "" {
@ -102,7 +105,8 @@ func configDefault(config ...Config) Config {
} }
selectors := strings.Split(cfg.KeyLookup, ":") selectors := strings.Split(cfg.KeyLookup, ":")
if len(selectors) != 2 { const numSelectors = 2
if len(selectors) != numSelectors {
panic("[session] KeyLookup must in the form of <source>:<name>") panic("[session] KeyLookup must in the form of <source>:<name>")
} }
switch Source(selectors[0]) { switch Source(selectors[0]) {

View File

@ -13,6 +13,7 @@ type data struct {
Data map[string]interface{} Data map[string]interface{}
} }
//nolint:gochecknoglobals // TODO: Do not use a global var here
var dataPool = sync.Pool{ var dataPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
d := new(data) d := new(data)
@ -22,7 +23,7 @@ var dataPool = sync.Pool{
} }
func acquireData() *data { func acquireData() *data {
return dataPool.Get().(*data) return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool
} }
func (d *data) Reset() { func (d *data) Reset() {

View File

@ -3,11 +3,13 @@ package session
import ( import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
"fmt"
"sync" "sync"
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -21,6 +23,7 @@ type Session struct {
exp time.Duration // expiration of this session exp time.Duration // expiration of this session
} }
//nolint:gochecknoglobals // TODO: Do not use a global var here
var sessionPool = sync.Pool{ var sessionPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
return new(Session) return new(Session)
@ -28,7 +31,7 @@ var sessionPool = sync.Pool{
} }
func acquireSession() *Session { func acquireSession() *Session {
s := sessionPool.Get().(*Session) s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
if s.data == nil { if s.data == nil {
s.data = acquireData() s.data = acquireData()
} }
@ -153,7 +156,7 @@ func (s *Session) Save() error {
encCache := gob.NewEncoder(s.byteBuffer) encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data) err := encCache.Encode(&s.data.Data)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to encode data: %w", err)
} }
// copy the data in buffer // copy the data in buffer

View File

@ -7,6 +7,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory" "github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -94,6 +95,8 @@ func Test_Session(t *testing.T) {
} }
// go test -run Test_Session_Types // go test -run Test_Session_Types
//
//nolint:forcetypeassert // TODO: Do not force-type assert
func Test_Session_Types(t *testing.T) { func Test_Session_Types(t *testing.T) {
t.Parallel() t.Parallel()
@ -127,25 +130,27 @@ func Test_Session_Types(t *testing.T) {
Name: "John", Name: "John",
} }
// set value // set value
var vbool = true var (
var vstring = "str" vbool = true
var vint = 13 vstring = "str"
var vint8 int8 = 13 vint = 13
var vint16 int16 = 13 vint8 int8 = 13
var vint32 int32 = 13 vint16 int16 = 13
var vint64 int64 = 13 vint32 int32 = 13
var vuint uint = 13 vint64 int64 = 13
var vuint8 uint8 = 13 vuint uint = 13
var vuint16 uint16 = 13 vuint8 uint8 = 13
var vuint32 uint32 = 13 vuint16 uint16 = 13
var vuint64 uint64 = 13 vuint32 uint32 = 13
var vuintptr uintptr = 13 vuint64 uint64 = 13
var vbyte byte = 'k' vuintptr uintptr = 13
var vrune rune = 'k' vbyte byte = 'k'
var vfloat32 float32 = 13 vrune = 'k'
var vfloat64 float64 = 13 vfloat32 float32 = 13
var vcomplex64 complex64 = 13 vfloat64 float64 = 13
var vcomplex128 complex128 = 13 vcomplex64 complex64 = 13
vcomplex128 complex128 = 13
)
sess.Set("vuser", vuser) sess.Set("vuser", vuser)
sess.Set("vbool", vbool) sess.Set("vbool", vbool)
sess.Set("vstring", vstring) sess.Set("vstring", vstring)
@ -212,7 +217,8 @@ func Test_Session_Store_Reset(t *testing.T) {
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// make sure its new // make sure its new
utils.AssertEqual(t, true, sess.Fresh()) utils.AssertEqual(t, true, sess.Fresh())
// set value & save // set value & save
@ -224,7 +230,8 @@ func Test_Session_Store_Reset(t *testing.T) {
utils.AssertEqual(t, nil, store.Reset()) utils.AssertEqual(t, nil, store.Reset())
// make sure the session is recreated // make sure the session is recreated
sess, _ = store.Get(ctx) sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh()) utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, nil, sess.Get("hello")) utils.AssertEqual(t, nil, sess.Get("hello"))
} }
@ -242,12 +249,13 @@ func Test_Session_Save(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value // set value
sess.Set("name", "john") sess.Set("name", "john")
// save session // save session
err := sess.Save() err = sess.Save()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
}) })
@ -262,12 +270,13 @@ func Test_Session_Save(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value // set value
sess.Set("name", "john") sess.Set("name", "john")
// save session // save session
err := sess.Save() err = sess.Save()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName))) utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName))) utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
@ -287,7 +296,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value // set value
sess.Set("name", "john") sess.Set("name", "john")
@ -295,18 +305,20 @@ func Test_Session_Save_Expiration(t *testing.T) {
sess.SetExpiry(time.Second * 5) sess.SetExpiry(time.Second * 5)
// save session // save session
err := sess.Save() err = sess.Save()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
// here you need to get the old session yet // here you need to get the old session yet
sess, _ = store.Get(ctx) sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "john", sess.Get("name")) utils.AssertEqual(t, "john", sess.Get("name"))
// just to make sure the session has been expired // just to make sure the session has been expired
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
// here you should get a new session // here you should get a new session
sess, _ = store.Get(ctx) sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, sess.Get("name")) utils.AssertEqual(t, nil, sess.Get("name"))
}) })
} }
@ -325,7 +337,8 @@ func Test_Session_Reset(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("name", "fenny") sess.Set("name", "fenny")
utils.AssertEqual(t, nil, sess.Destroy()) utils.AssertEqual(t, nil, sess.Destroy())
@ -345,14 +358,16 @@ func Test_Session_Reset(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value & save // set value & save
sess.Set("name", "fenny") sess.Set("name", "fenny")
utils.AssertEqual(t, nil, sess.Save()) utils.AssertEqual(t, nil, sess.Save())
sess, _ = store.Get(ctx) sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
err := sess.Destroy() err = sess.Destroy()
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName))) utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName))) utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
@ -383,7 +398,8 @@ func Test_Session_Cookie(t *testing.T) {
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, sess.Save()) utils.AssertEqual(t, nil, sess.Save())
// cookie should be set on Save ( even if empty data ) // cookie should be set on Save ( even if empty data )
@ -401,12 +417,14 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
defer app.ReleaseCtx(ctx) defer app.ReleaseCtx(ctx)
// get session // get session
sess, _ := store.Get(ctx) sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("id", "1") sess.Set("id", "1")
utils.AssertEqual(t, true, sess.Fresh()) utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, nil, sess.Save()) utils.AssertEqual(t, nil, sess.Save())
sess, _ = store.Get(ctx) sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("name", "john") sess.Set("name", "john")
utils.AssertEqual(t, true, sess.Fresh()) utils.AssertEqual(t, true, sess.Fresh())
@ -497,7 +515,7 @@ func Benchmark_Session(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
sess, _ := store.Get(c) sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe") sess.Set("john", "doe")
err = sess.Save() err = sess.Save()
} }
@ -512,7 +530,7 @@ func Benchmark_Session(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
sess, _ := store.Get(c) sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe") sess.Set("john", "doe")
err = sess.Save() err = sess.Save()
} }

View File

@ -2,11 +2,13 @@ package session
import ( import (
"encoding/gob" "encoding/gob"
"fmt"
"sync" "sync"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory" "github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -14,6 +16,7 @@ type Store struct {
Config Config
} }
//nolint:gochecknoglobals // TODO: Do not use a global var here
var mux sync.Mutex var mux sync.Mutex
func New(config ...Config) *Store { func New(config ...Config) *Store {
@ -31,7 +34,7 @@ func New(config ...Config) *Store {
// RegisterType will allow you to encode/decode custom types // RegisterType will allow you to encode/decode custom types
// into any Storage provider // into any Storage provider
func (s *Store) RegisterType(i interface{}) { func (*Store) RegisterType(i interface{}) {
gob.Register(i) gob.Register(i)
} }
@ -70,11 +73,11 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
if raw != nil && err == nil { if raw != nil && err == nil {
mux.Lock() mux.Lock()
defer mux.Unlock() defer mux.Unlock()
_, _ = sess.byteBuffer.Write(raw) _, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(sess.byteBuffer) encCache := gob.NewDecoder(sess.byteBuffer)
err := encCache.Decode(&sess.data.Data) err := encCache.Decode(&sess.data.Data)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decode session data: %w", err)
} }
} else if err != nil { } else if err != nil {
return nil, err return nil, err

View File

@ -6,6 +6,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )

View File

@ -1,6 +1,8 @@
package skip package skip
import "github.com/gofiber/fiber/v2" import (
"github.com/gofiber/fiber/v2"
)
// New creates a middleware handler which skips the wrapped handler // New creates a middleware handler which skips the wrapped handler
// if the exclude predicate returns true. // if the exclude predicate returns true.

View File

@ -17,7 +17,7 @@ func Test_Skip(t *testing.T) {
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return true })) app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return true }))
app.Get("/", helloWorldHandler) app.Get("/", helloWorldHandler)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
} }
@ -30,7 +30,7 @@ func Test_SkipFalse(t *testing.T) {
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return false })) app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return false }))
app.Get("/", helloWorldHandler) app.Get("/", helloWorldHandler)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
} }
@ -43,7 +43,7 @@ func Test_SkipNilFunc(t *testing.T) {
app.Use(skip.New(errTeapotHandler, nil)) app.Use(skip.New(errTeapotHandler, nil))
app.Get("/", helloWorldHandler) app.Get("/", helloWorldHandler)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
} }

View File

@ -18,7 +18,8 @@ func Test_Timeout(t *testing.T) {
// fiber instance // fiber instance
app := fiber.New() app := fiber.New()
h := New(func(c *fiber.Ctx) error { h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
utils.AssertEqual(t, nil, err)
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil { if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err)) return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
} }
@ -26,12 +27,12 @@ func Test_Timeout(t *testing.T) {
}, 100*time.Millisecond) }, 100*time.Millisecond)
app.Get("/test/:sleepTime", h) app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) { testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
} }
testSucces := func(timeoutStr string) { testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
} }
@ -49,7 +50,8 @@ func Test_TimeoutWithCustomError(t *testing.T) {
// fiber instance // fiber instance
app := fiber.New() app := fiber.New()
h := New(func(c *fiber.Ctx) error { h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
utils.AssertEqual(t, nil, err)
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil { if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err) return fmt.Errorf("%w: execution error", err)
} }
@ -57,12 +59,12 @@ func Test_TimeoutWithCustomError(t *testing.T) {
}, 100*time.Millisecond, ErrFooTimeOut) }, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h) app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) { testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
} }
testSucces := func(timeoutStr string) { testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
} }

Some files were not shown because too many files have changed in this diff Show More