🚀 Feature: Add and apply more stricter golangci-lint linting rules (#2286)

* golangci-lint: add and apply more stricter linting rules

* github: drop security workflow now that we use gosec linter inside golangci-lint

* github: use official golangci-lint CI linter

* Add editorconfig and gitattributes file
pull/2313/head
leonklingele 2023-01-27 09:01:37 +01:00 committed by GitHub
parent 7327a17951
commit 167a8b5e94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
120 changed files with 1889 additions and 1321 deletions

8
.editorconfig Normal file
View File

@ -0,0 +1,8 @@
; This file is for unifying the coding style for different editors and IDEs.
; More information at http://editorconfig.org
; This style originates from https://github.com/fewagency/best-practices
root = true
[*]
charset = utf-8
end_of_line = lf

12
.gitattributes vendored Normal file
View File

@ -0,0 +1,12 @@
# Handle line endings automatically for files detected as text
# and leave all files detected as binary untouched.
* text=auto eol=lf
# Force batch scripts to always use CRLF line endings so that if a repo is accessed
# in Windows via a file share from Linux, the scripts will work.
*.{cmd,[cC][mM][dD]} text eol=crlf
*.{bat,[bB][aA][tT]} text eol=crlf
# Force bash scripts to always use LF line endings so that if a repo is accessed
# in Unix via a file share from Windows, the scripts will work.
*.sh text eol=lf

View File

@ -1,17 +1,28 @@
# Adapted from https://github.com/golangci/golangci-lint-action/blob/b56f6f529003f1c81d4d759be6bd5f10bf9a0fa0/README.md#how-to-use
name: golangci-lint
on:
push:
branches:
- master
- main
pull_request:
name: Linter
push:
tags:
- v*
branches:
- master
- main
pull_request:
permissions:
contents: read
jobs:
Golint:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- name: Fetch Repository
uses: actions/checkout@v3
- name: Run Golint
uses: reviewdog/action-golangci-lint@v2
- uses: actions/setup-go@v3
with:
golangci_lint_flags: "--tests=false"
# NOTE: Keep this in sync with the version from go.mod
go-version: 1.19
- uses: actions/checkout@v3
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
# NOTE: Keep this in sync with the version from .golangci.yml
version: v1.50.1

View File

@ -1,17 +0,0 @@
on:
push:
branches:
- master
- main
pull_request:
name: Security
jobs:
Gosec:
runs-on: ubuntu-latest
steps:
- name: Fetch Repository
uses: actions/checkout@v3
- name: Run Gosec
uses: securego/gosec@master
with:
args: -exclude-dir=internal/*/ ./...

258
.golangci.yml Normal file
View File

@ -0,0 +1,258 @@
# Created based on v1.50.1
# NOTE: Keep this in sync with the version in .github/workflows/linter.yml
run:
modules-download-mode: readonly
skip-dirs-use-default: false
skip-dirs:
- internal # TODO: Also apply proper linting for internal dir
output:
sort-results: true
linters-settings:
# TODO: Eventually enable these checks
# depguard:
# include-go-root: true
# packages:
# - flag
# - io/ioutil
# - reflect
# - unsafe
# packages-with-error-message:
# - flag: '`flag` package is only allowed in main.go'
# - io/ioutil: '`io/ioutil` package is deprecated, use the `io` and `os` package instead'
# - reflect: '`reflect` package is dangerous to use'
# - unsafe: '`unsafe` package is dangerous to use'
errcheck:
check-type-assertions: true
check-blank: true
disable-default-exclusions: true
errchkjson:
report-no-exported: true
exhaustive:
default-signifies-exhaustive: true
forbidigo:
forbid:
- ^(fmt\.Print(|f|ln)|print|println)$
- 'http\.Default(Client|Transport)'
# TODO: Eventually enable these patterns
# - 'time\.Sleep'
# - 'panic'
gci:
sections:
- standard
- prefix(github.com/gofiber/fiber)
- default
- blank
- dot
custom-order: true
gocritic:
disabled-checks:
- ifElseChain
gofumpt:
module-path: github.com/gofiber/fiber
extra-rules: true
gosec:
config:
global:
audit: true
govet:
check-shadowing: true
enable-all: true
disable:
- shadow
- fieldalignment
grouper:
import-require-single-import: true
import-require-grouping: true
misspell:
locale: US
nolintlint:
require-explanation: true
require-specific: true
nonamedreturns:
report-error-in-defer: true
predeclared:
q: true
promlinter:
strict: true
revive:
enable-all-rules: true
rules:
# Provided by gomnd linter
- name: add-constant
disabled: true
- name: argument-limit
disabled: true
# Provided by bidichk
- name: banned-characters
disabled: true
- name: cognitive-complexity
disabled: true
- name: cyclomatic
disabled: true
- name: exported
disabled: true
- name: file-header
disabled: true
- name: function-result-limit
disabled: true
- name: function-length
disabled: true
- name: line-length-limit
disabled: true
- name: max-public-structs
disabled: true
- name: modifies-parameter
disabled: true
- name: nested-structs
disabled: true
- name: package-comments
disabled: true
stylecheck:
checks:
- all
- -ST1000
- -ST1020
- -ST1021
- -ST1022
tagliatelle:
case:
rules:
json: snake
#tenv:
# all: true
#unparam:
# check-exported: true
wrapcheck:
ignorePackageGlobs:
- github.com/gofiber/fiber/*
- github.com/valyala/fasthttp
issues:
exclude-use-default: false
linters:
enable:
- asasalint
- asciicheck
- bidichk
- bodyclose
- containedctx
- contextcheck
# - cyclop
# - deadcode
# - decorder
- depguard
- dogsled
# - dupl
# - dupword
- durationcheck
- errcheck
- errchkjson
- errname
- errorlint
- execinquery
- exhaustive
# - exhaustivestruct
# - exhaustruct
- exportloopref
- forbidigo
- forcetypeassert
# - funlen
- gci
- gochecknoglobals
- gochecknoinits
# - gocognit
- goconst
- gocritic
# - gocyclo
# - godot
# - godox
# - goerr113
- gofmt
- gofumpt
# - goheader
- goimports
# - golint
- gomnd
- gomoddirectives
# - gomodguard
- goprintffuncname
- gosec
- gosimple
- govet
- grouper
# - ifshort
# - importas
- ineffassign
# - interfacebloat
# - interfacer
# - ireturn
# - lll
- loggercheck
# - maintidx
# - makezero
# - maligned
- misspell
- nakedret
# - nestif
- nilerr
- nilnil
# - nlreturn
- noctx
- nolintlint
- nonamedreturns
# - nosnakecase
- nosprintfhostport
# - paralleltest # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
# - prealloc
- predeclared
- promlinter
- reassign
- revive
- rowserrcheck
# - scopelint
- sqlclosecheck
- staticcheck
# - structcheck
- stylecheck
- tagliatelle
# - tenv # TODO: Enable once we drop support for Go 1.16
# - testableexamples
# - testpackage # TODO: Enable once https://github.com/gofiber/fiber/issues/2252 is implemented
- thelper
# - tparallel # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
- typecheck
- unconvert
- unparam
- unused
- usestdlibvars
# - varcheck
# - varnamelen
- wastedassign
- whitespace
- wrapcheck
# - wsl

58
app.go
View File

@ -14,6 +14,7 @@ import (
"encoding/xml"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/http/httputil"
@ -24,6 +25,7 @@ import (
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -306,7 +308,7 @@ type Config struct {
// FEATURE: v2.3.x
// The router executes the same handler by default if StrictRouting or CaseSensitive is disabled.
// Enabling RedirectFixedPath will change this behaviour into a client redirect to the original route path.
// Enabling RedirectFixedPath will change this behavior into a client redirect to the original route path.
// Using the status code 301 for GET requests and 308 for all other request methods.
//
// Default: false
@ -454,6 +456,8 @@ const (
)
// HTTP methods enabled by default
//
//nolint:gochecknoglobals // Using a global var is fine here
var DefaultMethods = []string{
MethodGet,
MethodHead,
@ -467,7 +471,7 @@ var DefaultMethods = []string{
}
// DefaultErrorHandler that process return errors from handlers
var DefaultErrorHandler = func(c *Ctx, err error) error {
func DefaultErrorHandler(c *Ctx, err error) error {
code := StatusInternalServerError
var e *Error
if errors.As(err, &e) {
@ -519,7 +523,7 @@ func New(config ...Config) *App {
if app.config.ETag {
if !IsChild() {
fmt.Println("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.")
log.Printf("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.\n")
}
}
@ -587,7 +591,7 @@ func (app *App) handleTrustedProxy(ipAddress string) {
if strings.Contains(ipAddress, "/") {
_, ipNet, err := net.ParseCIDR(ipAddress)
if err != nil {
fmt.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err)
log.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err)
} else {
app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet)
}
@ -822,7 +826,7 @@ func (app *App) Config() Config {
}
// Handler returns the server handler.
func (app *App) Handler() fasthttp.RequestHandler {
func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
// prepare the server for the start
app.startupProcess()
return app.handler
@ -887,7 +891,7 @@ func (app *App) Hooks() *Hooks {
// Test is used for internal debugging by passing a *http.Request.
// Timeout is optional and defaults to 1s, -1 will disable it completely.
func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response, err error) {
func (app *App) Test(req *http.Request, msTimeout ...int) (*http.Response, error) {
// Set timeout
timeout := 1000
if len(msTimeout) > 0 {
@ -902,15 +906,15 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
// Dump raw http request
dump, err := httputil.DumpRequest(req, true)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to dump request: %w", err)
}
// Create test connection
conn := new(testConn)
// Write raw http request
if _, err = conn.r.Write(dump); err != nil {
return nil, err
if _, err := conn.r.Write(dump); err != nil {
return nil, fmt.Errorf("failed to write: %w", err)
}
// prepare the server for the start
app.startupProcess()
@ -943,7 +947,7 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
}
// Check for errors
if err != nil && err != fasthttp.ErrGetOnly {
if err != nil && !errors.Is(err, fasthttp.ErrGetOnly) {
return nil, err
}
@ -951,12 +955,17 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
buffer := bufio.NewReader(&conn.w)
// Convert raw http response to *http.Response
return http.ReadResponse(buffer, req)
res, err := http.ReadResponse(buffer, req)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
return res, nil
}
type disableLogger struct{}
func (dl *disableLogger) Printf(_ string, _ ...interface{}) {
func (*disableLogger) Printf(_ string, _ ...interface{}) {
// fmt.Println(fmt.Sprintf(format, args...))
}
@ -967,7 +976,7 @@ func (app *App) init() *App {
// Only load templates if a view engine is specified
if app.config.Views != nil {
if err := app.config.Views.Load(); err != nil {
fmt.Printf("views: %v\n", err)
log.Printf("[Warning]: failed to load views: %v\n", err)
}
}
@ -1039,25 +1048,30 @@ func (app *App) ErrorHandler(ctx *Ctx, err error) error {
// errors before calling the application's error handler method.
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
defer app.ReleaseCtx(c)
var errNetOP *net.OpError
switch {
case errors.As(err, new(*fasthttp.ErrSmallBuffer)):
err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
case errors.As(err, &errNetOP) && errNetOP.Timeout():
err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge {
case errors.Is(err, fasthttp.ErrBodyTooLarge):
err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly {
case errors.Is(err, fasthttp.ErrGetOnly):
err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") {
case strings.Contains(err.Error(), "timeout"):
err = ErrRequestTimeout
} else {
default:
err = NewError(StatusBadRequest, err.Error())
}
if catch := app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
log.Printf("serverErrorHandler: failed to call ErrorHandler: %v\n", catch)
_ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here
return
}
app.ReleaseCtx(c)
}
// startupProcess Is the method which executes all the necessary processes just before the start of the server.

View File

@ -2,6 +2,7 @@
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package fiber
import (
@ -23,15 +24,16 @@ import (
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
var testEmptyHandler = func(c *Ctx) error {
func testEmptyHandler(_ *Ctx) error {
return nil
}
func testStatus200(t *testing.T, app *App, url string, method string) {
func testStatus200(t *testing.T, app *App, url, method string) {
t.Helper()
req := httptest.NewRequest(method, url, nil)
@ -42,6 +44,8 @@ func testStatus200(t *testing.T, app *App, url string, method string) {
}
func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBodyError string) {
t.Helper()
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 500, resp.StatusCode, "Status code")
@ -140,7 +144,6 @@ func Test_App_ServerErrorHandler_SmallReadBuffer(t *testing.T) {
logHeaderSlice := make([]string, 5000)
request.Header.Set("Very-Long-Header", strings.Join(logHeaderSlice, "-"))
_, err := app.Test(request)
if err == nil {
t.Error("Expect an error at app.Test(request)")
}
@ -470,7 +473,6 @@ func Test_App_Use_MultiplePrefix(t *testing.T) {
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "/test/doe", string(body))
}
func Test_App_Use_StrictRouting(t *testing.T) {
@ -515,7 +517,7 @@ func Test_App_Add_Method_Test(t *testing.T) {
}
}()
methods := append(DefaultMethods, "JOHN")
methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here
app := New(Config{
RequestMethods: methods,
})
@ -780,7 +782,7 @@ func Test_App_ShutdownWithTimeout(t *testing.T) {
case <-timer.C:
t.Fatal("idle connections not closed on shutdown")
case err := <-shutdownErr:
if err == nil || err != context.DeadlineExceeded {
if err == nil || !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
}
}
@ -852,7 +854,7 @@ func Test_App_Static_MaxAge(t *testing.T) {
app.Static("/", "./.github", Static{MaxAge: 100})
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil))
resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
@ -866,19 +868,19 @@ func Test_App_Static_Custom_CacheControl(t *testing.T) {
app := New()
app.Static("/", "./.github", Static{ModifyResponse: func(c *Ctx) error {
if strings.Contains(string(c.GetRespHeader("Content-Type")), "text/html") {
if strings.Contains(c.GetRespHeader("Content-Type"), "text/html") {
c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
}
return nil
}})
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil))
resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, "no-cache, no-store, must-revalidate", resp.Header.Get(HeaderCacheControl), "CacheControl Control")
normal_resp, normal_err := app.Test(httptest.NewRequest("GET", "/config.yml", nil))
utils.AssertEqual(t, nil, normal_err, "app.Test(req)")
utils.AssertEqual(t, "", normal_resp.Header.Get(HeaderCacheControl), "CacheControl Control")
respNormal, errNormal := app.Test(httptest.NewRequest(MethodGet, "/config.yml", nil))
utils.AssertEqual(t, nil, errNormal, "app.Test(req)")
utils.AssertEqual(t, "", respNormal.Header.Get(HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_App_Static_Download
@ -890,7 +892,7 @@ func Test_App_Static_Download(t *testing.T) {
app.Static("/fiber.png", "./.github/testdata/fs/img/fiber.png", Static{Download: true})
resp, err := app.Test(httptest.NewRequest("GET", "/fiber.png", nil))
resp, err := app.Test(httptest.NewRequest(MethodGet, "/fiber.png", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
@ -1067,7 +1069,7 @@ func Test_App_Static_Next(t *testing.T) {
t.Run("app.Static is skipped: invoking Get handler", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(MethodGet, "/", nil)
req.Header.Set("X-Custom-Header", "skip")
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -1082,7 +1084,7 @@ func Test_App_Static_Next(t *testing.T) {
t.Run("app.Static is not skipped: serving index.html", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(MethodGet, "/", nil)
req.Header.Set("X-Custom-Header", "don't skip")
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -1379,7 +1381,7 @@ func Test_Test_DumpError(t *testing.T) {
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", errorReader(0)))
utils.AssertEqual(t, true, resp == nil)
utils.AssertEqual(t, "errorReader", err.Error())
utils.AssertEqual(t, "failed to dump request: errorReader", err.Error())
}
// go test -run Test_App_Handler
@ -1393,7 +1395,7 @@ type invalidView struct{}
func (invalidView) Load() error { return errors.New("invalid view") }
func (i invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") }
func (invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") }
// go test -run Test_App_Init_Error_View
func Test_App_Init_Error_View(t *testing.T) {
@ -1405,7 +1407,9 @@ func Test_App_Init_Error_View(t *testing.T) {
utils.AssertEqual(t, "implement me", fmt.Sprintf("%v", err))
}
}()
_ = app.config.Views.Render(nil, "", nil)
err := app.config.Views.Render(nil, "", nil)
utils.AssertEqual(t, nil, err)
}
// go test -run Test_App_Stack
@ -1535,11 +1539,12 @@ func Test_App_SmallReadBuffer(t *testing.T) {
go func() {
time.Sleep(500 * time.Millisecond)
resp, err := http.Get("http://127.0.0.1:4006/small-read-buffer")
if resp != nil {
utils.AssertEqual(t, 431, resp.StatusCode)
}
req, err := http.NewRequestWithContext(context.Background(), MethodGet, "http://127.0.0.1:4006/small-read-buffer", http.NoBody)
utils.AssertEqual(t, nil, err)
var client http.Client
resp, err := client.Do(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 431, resp.StatusCode)
utils.AssertEqual(t, nil, app.Shutdown())
}()
@ -1572,13 +1577,13 @@ func Test_App_New_Test_Parallel(t *testing.T) {
t.Run("Test_App_New_Test_Parallel_1", func(t *testing.T) {
t.Parallel()
app := New(Config{Immutable: true})
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
_, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
})
t.Run("Test_App_New_Test_Parallel_2", func(t *testing.T) {
t.Parallel()
app := New(Config{Immutable: true})
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
_, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
})
}
@ -1591,7 +1596,7 @@ func Test_App_ReadBodyStream(t *testing.T) {
return c.SendString(fmt.Sprintf("%v %s", c.Request().IsBodyStream(), c.Body()))
})
testString := "this is a test"
resp, err := app.Test(httptest.NewRequest("POST", "/", bytes.NewBufferString(testString)))
resp, err := app.Test(httptest.NewRequest(MethodPost, "/", bytes.NewBufferString(testString)))
utils.AssertEqual(t, nil, err, "app.Test(req)")
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "io.ReadAll(resp.Body)")
@ -1615,12 +1620,12 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
}
file, err := mpf.File["test"][0].Open()
if err != nil {
return err
return fmt.Errorf("failed to open: %w", err)
}
buffer := make([]byte, len(testString))
n, err := file.Read(buffer)
if err != nil {
return err
return fmt.Errorf("failed to read: %w", err)
}
if n != len(testString) {
return fmt.Errorf("bad read length")
@ -1636,7 +1641,7 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
utils.AssertEqual(t, len(testString), n, "writer n")
utils.AssertEqual(t, nil, w.Close(), "w.Close()")
req := httptest.NewRequest("POST", "/", b)
req := httptest.NewRequest(MethodPost, "/", b)
req.Header.Set("Content-Type", w.FormDataContentType())
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -1659,7 +1664,7 @@ func Test_App_Test_no_timeout_infinitely(t *testing.T) {
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req := httptest.NewRequest(MethodGet, "/", http.NoBody)
_, err = app.Test(req, -1)
}()
@ -1696,7 +1701,7 @@ func Test_App_SetTLSHandler(t *testing.T) {
func Test_App_AddCustomRequestMethod(t *testing.T) {
t.Parallel()
methods := append(DefaultMethods, "TEST")
methods := append(DefaultMethods, "TEST") //nolint:gocritic // We want a new slice here
app := New(Config{
RequestMethods: methods,
})

View File

@ -15,6 +15,7 @@ import (
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -50,6 +51,7 @@ type Args = fasthttp.Args
// Copy from fasthttp
type RetryIfFunc = fasthttp.RetryIfFunc
//nolint:gochecknoglobals // TODO: Do not use a global var here
var defaultClient Client
// Client implements http client.
@ -186,11 +188,11 @@ func (a *Agent) Parse() error {
uri := a.req.URI()
isTLS := false
var isTLS bool
scheme := uri.Scheme()
if bytes.Equal(scheme, strHTTPS) {
if bytes.Equal(scheme, []byte(schemeHTTPS)) {
isTLS = true
} else if !bytes.Equal(scheme, strHTTP) {
} else if !bytes.Equal(scheme, []byte(schemeHTTP)) {
return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
}
@ -241,7 +243,7 @@ func (a *Agent) SetBytesV(k string, v []byte) *Agent {
// SetBytesKV sets the given 'key: value' header.
//
// Use AddBytesKV for setting multiple header values under the same key.
func (a *Agent) SetBytesKV(k []byte, v []byte) *Agent {
func (a *Agent) SetBytesKV(k, v []byte) *Agent {
a.req.Header.SetBytesKV(k, v)
return a
@ -281,7 +283,7 @@ func (a *Agent) AddBytesV(k string, v []byte) *Agent {
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
func (a *Agent) AddBytesKV(k []byte, v []byte) *Agent {
func (a *Agent) AddBytesKV(k, v []byte) *Agent {
a.req.Header.AddBytesKV(k, v)
return a
@ -652,10 +654,8 @@ func (a *Agent) Reuse() *Agent {
// certificate chain and host name.
func (a *Agent) InsecureSkipVerify() *Agent {
if a.HostClient.TLSConfig == nil {
/* #nosec G402 */
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We explicitly let the user set insecure mode here
} else {
/* #nosec G402 */
a.HostClient.TLSConfig.InsecureSkipVerify = true
}
@ -728,14 +728,14 @@ func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent {
// Bytes returns the status code, bytes body and errors of url.
//
// it's not safe to use Agent after calling [Agent.Bytes]
func (a *Agent) Bytes() (code int, body []byte, errs []error) {
func (a *Agent) Bytes() (int, []byte, []error) {
defer a.release()
return a.bytes()
}
func (a *Agent) bytes() (code int, body []byte, errs []error) {
func (a *Agent) bytes() (code int, body []byte, errs []error) { //nolint:nonamedreturns,revive // We want to overwrite the body in a deferred func. TODO: Check if we really need to do this. We eventually want to get rid of all named returns.
if errs = append(errs, a.errs...); len(errs) > 0 {
return
return code, body, errs
}
var (
@ -760,7 +760,7 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
code = resp.StatusCode()
}
body = append(a.dest, resp.Body()...)
body = append(a.dest, resp.Body()...) //nolint:gocritic // We want to append to the returned slice here
if nilResp {
ReleaseResponse(resp)
@ -770,25 +770,25 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
if a.timeout > 0 {
if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil {
errs = append(errs, err)
return
return code, body, errs
}
} else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) {
if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil {
errs = append(errs, err)
return
return code, body, errs
}
} else if err := a.HostClient.Do(req, resp); err != nil {
errs = append(errs, err)
}
return
return code, body, errs
}
func printDebugInfo(req *Request, resp *Response, w io.Writer) {
msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr())
_, _ = w.Write(utils.UnsafeBytes(msg))
_, _ = req.WriteTo(w)
_, _ = resp.WriteTo(w)
_, _ = w.Write(utils.UnsafeBytes(msg)) //nolint:errcheck // This will never fail
_, _ = req.WriteTo(w) //nolint:errcheck // This will never fail
_, _ = resp.WriteTo(w) //nolint:errcheck // This will never fail
}
// String returns the status code, string body and errors of url.
@ -797,6 +797,7 @@ func printDebugInfo(req *Request, resp *Response, w io.Writer) {
func (a *Agent) String() (int, string, []error) {
defer a.release()
code, body, errs := a.bytes()
// TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it?
return code, utils.UnsafeString(body), errs
}
@ -805,12 +806,15 @@ func (a *Agent) String() (int, string, []error) {
// And bytes body will be unmarshalled to given v.
//
// it's not safe to use Agent after calling [Agent.Struct]
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
func (a *Agent) Struct(v interface{}) (int, []byte, []error) {
defer a.release()
if code, body, errs = a.bytes(); len(errs) > 0 {
return
code, body, errs := a.bytes()
if len(errs) > 0 {
return code, body, errs
}
// TODO: This should only be done once
if a.jsonDecoder == nil {
a.jsonDecoder = json.Unmarshal
}
@ -819,7 +823,7 @@ func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
errs = append(errs, err)
}
return
return code, body, errs
}
func (a *Agent) release() {
@ -855,6 +859,7 @@ func (a *Agent) reset() {
a.formFiles = a.formFiles[:0]
}
//nolint:gochecknoglobals // TODO: Do not use global vars here
var (
clientPool sync.Pool
agentPool = sync.Pool{
@ -877,7 +882,11 @@ func AcquireClient() *Client {
if v == nil {
return &Client{}
}
return v.(*Client)
c, ok := v.(*Client)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Client"))
}
return c
}
// ReleaseClient returns c acquired via AcquireClient to client pool.
@ -899,7 +908,11 @@ func ReleaseClient(c *Client) {
// no longer needed. This allows Agent recycling, reduces GC pressure
// and usually improves performance.
func AcquireAgent() *Agent {
return agentPool.Get().(*Agent)
a, ok := agentPool.Get().(*Agent)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Agent"))
}
return a
}
// ReleaseAgent returns a acquired via AcquireAgent to Agent pool.
@ -922,7 +935,11 @@ func AcquireResponse() *Response {
if v == nil {
return &Response{}
}
return v.(*Response)
r, ok := v.(*Response)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Response"))
}
return r
}
// ReleaseResponse return resp acquired via AcquireResponse to response pool.
@ -945,7 +962,11 @@ func AcquireArgs() *Args {
if v == nil {
return &Args{}
}
return v.(*Args)
a, ok := v.(*Args)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Args"))
}
return a
}
// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
@ -966,7 +987,11 @@ func AcquireFormFile() *FormFile {
if v == nil {
return &FormFile{}
}
return v.(*FormFile)
ff, ok := v.(*FormFile)
if !ok {
panic(fmt.Errorf("failed to type-assert to *FormFile"))
}
return ff
}
// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool.
@ -981,9 +1006,7 @@ func ReleaseFormFile(ff *FormFile) {
formFilePool.Put(ff)
}
var (
strHTTP = []byte("http")
strHTTPS = []byte("https")
const (
defaultUserAgent = "fiber"
)

View File

@ -1,3 +1,4 @@
//nolint:wrapcheck // We must not wrap errors in tests
package fiber
import (
@ -19,6 +20,7 @@ import (
"github.com/gofiber/fiber/v2/internal/tlstest"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp/fasthttputil"
)
@ -295,8 +297,10 @@ func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) {
handler := func(c *Ctx) error {
c.Request().Header.VisitAll(func(key, value []byte) {
if k := string(key); k == "K1" || k == "K2" {
_, _ = c.Write(key)
_, _ = c.Write(value)
_, err := c.Write(key)
utils.AssertEqual(t, nil, err)
_, err = c.Write(value)
utils.AssertEqual(t, nil, err)
}
})
return nil
@ -581,25 +585,26 @@ type readErrorConn struct {
net.Conn
}
func (r *readErrorConn) Read(p []byte) (int, error) {
func (*readErrorConn) Read(_ []byte) (int, error) {
return 0, fmt.Errorf("error")
}
func (r *readErrorConn) Write(p []byte) (int, error) {
func (*readErrorConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (r *readErrorConn) Close() error {
func (*readErrorConn) Close() error {
return nil
}
func (r *readErrorConn) LocalAddr() net.Addr {
func (*readErrorConn) LocalAddr() net.Addr {
return nil
}
func (r *readErrorConn) RemoteAddr() net.Addr {
func (*readErrorConn) RemoteAddr() net.Addr {
return nil
}
func Test_Client_Agent_RetryIf(t *testing.T) {
t.Parallel()
@ -783,7 +788,10 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) {
buf := make([]byte, fh1.Size)
f, err := fh1.Open()
utils.AssertEqual(t, nil, err)
defer func() { _ = f.Close() }()
defer func() {
err := f.Close()
utils.AssertEqual(t, nil, err)
}()
_, err = f.Read(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "form file", string(buf))
@ -831,13 +839,16 @@ func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) {
basename := filepath.Base(filename)
utils.AssertEqual(t, fh.Filename, basename)
b1, err := os.ReadFile(filename)
b1, err := os.ReadFile(filename) //nolint:gosec // We're in a test so reading user-provided files by name is fine
utils.AssertEqual(t, nil, err)
b2 := make([]byte, fh.Size)
f, err := fh.Open()
utils.AssertEqual(t, nil, err)
defer func() { _ = f.Close() }()
defer func() {
err := f.Close()
utils.AssertEqual(t, nil, err)
}()
_, err = f.Read(b2)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, b1, b2)
@ -962,6 +973,7 @@ func Test_Client_Agent_InsecureSkipVerify(t *testing.T) {
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
utils.AssertEqual(t, nil, err)
//nolint:gosec // We're in a test so using old ciphers is fine
serverTLSConf := &tls.Config{
Certificates: []tls.Certificate{cer},
}
@ -1137,7 +1149,7 @@ func Test_Client_Agent_Struct(t *testing.T) {
defer ReleaseAgent(a)
defer a.ConnectionClose()
request := a.Request()
request.Header.SetMethod("GET")
request.Header.SetMethod(MethodGet)
request.SetRequestURI("http://example.com")
err := a.Parse()
utils.AssertEqual(t, nil, err)
@ -1198,8 +1210,8 @@ type errorMultipartWriter struct {
count int
}
func (e *errorMultipartWriter) Boundary() string { return "myBoundary" }
func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil }
func (*errorMultipartWriter) Boundary() string { return "myBoundary" }
func (*errorMultipartWriter) SetBoundary(_ string) error { return nil }
func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
if e.count == 0 {
e.count++
@ -1207,8 +1219,8 @@ func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
}
return errorWriter{}, nil
}
func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
func (e *errorMultipartWriter) Close() error { return errors.New("Close error") }
func (*errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
func (*errorMultipartWriter) Close() error { return errors.New("Close error") }
type errorWriter struct{}

View File

@ -53,6 +53,8 @@ type Colors struct {
}
// DefaultColors Default color codes
//
//nolint:gochecknoglobals // Using a global var is fine here
var DefaultColors = Colors{
Black: "\u001b[90m",
Red: "\u001b[91m",

188
ctx.go
View File

@ -27,10 +27,16 @@ import (
"github.com/gofiber/fiber/v2/internal/dictpool"
"github.com/gofiber/fiber/v2/internal/schema"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
const (
schemeHTTP = "http"
schemeHTTPS = "https"
)
// maxParams defines the maximum number of parameters per route.
const maxParams = 30
@ -45,6 +51,7 @@ const (
// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
const userContextKey = "__local_user_context__"
//nolint:gochecknoglobals // TODO: Do not use global vars here
var (
// decoderPoolMap helps to improve BodyParser's, QueryParser's and ReqHeaderParser's performance
decoderPoolMap = map[string]*sync.Pool{}
@ -52,6 +59,7 @@ var (
tags = []string{queryTag, bodyTag, reqHeaderTag, paramsTag}
)
//nolint:gochecknoinits // init() is used to initialize a global map variable
func init() {
for _, tag := range tags {
decoderPoolMap[tag] = &sync.Pool{New: func() interface{} {
@ -100,9 +108,10 @@ type TLSHandler struct {
}
// GetClientInfo Callback function to set CHI
// TODO: Why is this a getter which sets stuff?
func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
t.clientHelloInfo = info
return nil, nil
return nil, nil //nolint:nilnil // Not returning anything useful here is probably fine
}
// Range data for c.Range
@ -151,7 +160,10 @@ type ParserConfig struct {
// AcquireCtx retrieves a new Ctx from the pool.
func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
c := app.pool.Get().(*Ctx)
c, ok := app.pool.Get().(*Ctx)
if !ok {
panic(fmt.Errorf("failed to type-assert to *Ctx"))
}
// Set app reference
c.app = app
// Reset route and handler index
@ -388,7 +400,6 @@ func (c *Ctx) BodyParser(out interface{}) error {
} else {
data[k] = append(data[k], v)
}
})
return c.parseToStruct(bodyTag, out, data)
@ -401,7 +412,10 @@ func (c *Ctx) BodyParser(out interface{}) error {
return c.parseToStruct(bodyTag, out, data.Value)
}
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
return xml.Unmarshal(c.Body(), out)
if err := xml.Unmarshal(c.Body(), out); err != nil {
return fmt.Errorf("failed to unmarshal: %w", err)
}
return nil
}
// No suitable content type found
return ErrUnprocessableEntity
@ -673,8 +687,11 @@ func (c *Ctx) Hostname() string {
// Port returns the remote port of the request.
func (c *Ctx) Port() string {
port := c.fasthttp.RemoteAddr().(*net.TCPAddr).Port
return strconv.Itoa(port)
tcpaddr, ok := c.fasthttp.RemoteAddr().(*net.TCPAddr)
if !ok {
panic(fmt.Errorf("failed to type-assert to *net.TCPAddr"))
}
return strconv.Itoa(tcpaddr.Port)
}
// IP returns the remote IP address of the request.
@ -691,13 +708,16 @@ func (c *Ctx) IP() string {
// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
// When IP validation is enabled, any invalid IPs will be omitted.
func (c *Ctx) extractIPsFromHeader(header string) []string {
// TODO: Reuse the c.extractIPFromHeader func somehow in here
headerValue := c.Get(header)
// We can't know how many IPs we will return, but we will try to guess with this constant division.
// Counting ',' makes function slower for about 50ns in general case.
estimatedCount := len(headerValue) / 8
if estimatedCount > 8 {
estimatedCount = 8 // Avoid big allocation on big header
const maxEstimatedCount = 8
estimatedCount := len(headerValue) / maxEstimatedCount
if estimatedCount > maxEstimatedCount {
estimatedCount = maxEstimatedCount // Avoid big allocation on big header
}
ipsFound := make([]string, 0, estimatedCount)
@ -707,11 +727,10 @@ func (c *Ctx) extractIPsFromHeader(header string) []string {
iploop:
for {
v4 := false
v6 := false
var v4, v6 bool
// Manually splitting string without allocating slice, working with parts directly
i, j = j+1, j+2
i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
if j > len(headerValue) {
break
@ -758,9 +777,10 @@ func (c *Ctx) extractIPFromHeader(header string) string {
iploop:
for {
v4 := false
v6 := false
i, j = j+1, j+2
var v4, v6 bool
// Manually splitting string without allocating slice, working with parts directly
i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
if j > len(headerValue) {
break
@ -793,14 +813,14 @@ func (c *Ctx) extractIPFromHeader(header string) string {
return c.fasthttp.RemoteIP().String()
}
// default behaviour if IP validation is not enabled is just to return whatever value is
// default behavior if IP validation is not enabled is just to return whatever value is
// in the proxy header. Even if it is empty or invalid
return c.Get(c.app.config.ProxyHeader)
}
// IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header.
// When IP validation is enabled, only valid IPs are returned.
func (c *Ctx) IPs() (ips []string) {
func (c *Ctx) IPs() []string {
return c.extractIPsFromHeader(HeaderXForwardedFor)
}
@ -839,7 +859,7 @@ func (c *Ctx) JSON(data interface{}) error {
func (c *Ctx) JSONP(data interface{}, callback ...string) error {
raw, err := json.Marshal(data)
if err != nil {
return err
return fmt.Errorf("failed to marshal: %w", err)
}
var result, cb string
@ -877,11 +897,11 @@ func (c *Ctx) Links(link ...string) {
bb := bytebufferpool.Get()
for i := range link {
if i%2 == 0 {
_ = bb.WriteByte('<')
_, _ = bb.WriteString(link[i])
_ = bb.WriteByte('>')
_ = bb.WriteByte('<') //nolint:errcheck // This will never fail
_, _ = bb.WriteString(link[i]) //nolint:errcheck // This will never fail
_ = bb.WriteByte('>') //nolint:errcheck // This will never fail
} else {
_, _ = bb.WriteString(`; rel="` + link[i] + `",`)
_, _ = bb.WriteString(`; rel="` + link[i] + `",`) //nolint:errcheck // This will never fail
}
}
c.setCanonical(HeaderLink, utils.TrimRight(c.app.getString(bb.Bytes()), ','))
@ -890,7 +910,7 @@ func (c *Ctx) Links(link ...string) {
// Locals makes it possible to pass interface{} values under keys scoped to the request
// and therefore available to all following routes that match the request.
func (c *Ctx) Locals(key interface{}, value ...interface{}) (val interface{}) {
func (c *Ctx) Locals(key interface{}, value ...interface{}) interface{} {
if len(value) == 0 {
return c.fasthttp.UserValue(key)
}
@ -933,9 +953,10 @@ func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo {
}
// Next executes the next method in the stack that matches the current route.
func (c *Ctx) Next() (err error) {
func (c *Ctx) Next() error {
// Increment handler index
c.indexHandler++
var err error
// Did we executed all route handlers?
if c.indexHandler < len(c.route.Handlers) {
// Continue route stack
@ -947,7 +968,7 @@ func (c *Ctx) Next() (err error) {
return err
}
// RestartRouting instead of going to the next handler. This may be usefull after
// RestartRouting instead of going to the next handler. This may be useful after
// changing the request path. Note that handlers might be executed again.
func (c *Ctx) RestartRouting() error {
c.indexRoute = -1
@ -1017,9 +1038,8 @@ func (c *Ctx) ParamsInt(key string, defaultValue ...int) (int, error) {
if err != nil {
if len(defaultValue) > 0 {
return defaultValue[0], nil
} else {
return 0, err
}
return 0, fmt.Errorf("failed to convert: %w", err)
}
return value, nil
@ -1044,15 +1064,16 @@ func (c *Ctx) Path(override ...string) string {
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
func (c *Ctx) Protocol() string {
if c.fasthttp.IsTLS() {
return "https"
return schemeHTTPS
}
if !c.IsProxyTrusted() {
return "http"
return schemeHTTP
}
scheme := "http"
scheme := schemeHTTP
const lenXHeaderName = 12
c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
if len(key) < 12 {
if len(key) < lenXHeaderName {
return // Neither "X-Forwarded-" nor "X-Url-Scheme"
}
switch {
@ -1067,7 +1088,7 @@ func (c *Ctx) Protocol() string {
scheme = v
}
} else if bytes.Equal(key, []byte(HeaderXForwardedSsl)) && bytes.Equal(val, []byte("on")) {
scheme = "https"
scheme = schemeHTTPS
}
case bytes.Equal(key, []byte(HeaderXUrlScheme)):
@ -1100,9 +1121,8 @@ func (c *Ctx) QueryInt(key string, defaultValue ...int) int {
if err != nil {
if len(defaultValue) > 0 {
return defaultValue[0]
} else {
return 0
}
return 0
}
return value
@ -1133,7 +1153,6 @@ func (c *Ctx) QueryParser(out interface{}) error {
} else {
data[k] = append(data[k], v)
}
})
if err != nil {
@ -1150,10 +1169,9 @@ func parseParamSquareBrackets(k string) (string, error) {
kbytes := []byte(k)
for i, b := range kbytes {
if b == '[' && kbytes[i+1] != ']' {
if err := bb.WriteByte('.'); err != nil {
return "", err
return "", fmt.Errorf("failed to write: %w", err)
}
}
@ -1162,7 +1180,7 @@ func parseParamSquareBrackets(k string) (string, error) {
}
if err := bb.WriteByte(b); err != nil {
return "", err
return "", fmt.Errorf("failed to write: %w", err)
}
}
@ -1184,21 +1202,27 @@ func (c *Ctx) ReqHeaderParser(out interface{}) error {
} else {
data[k] = append(data[k], v)
}
})
return c.parseToStruct(reqHeaderTag, out, data)
}
func (c *Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error {
func (*Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error {
// Get decoder from pool
schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder)
schemaDecoder, ok := decoderPoolMap[aliasTag].Get().(*schema.Decoder)
if !ok {
panic(fmt.Errorf("failed to type-assert to *schema.Decoder"))
}
defer decoderPoolMap[aliasTag].Put(schemaDecoder)
// Set alias tag
schemaDecoder.SetAliasTag(aliasTag)
return schemaDecoder.Decode(out, data)
if err := schemaDecoder.Decode(out, data); err != nil {
return fmt.Errorf("failed to decode: %w", err)
}
return nil
}
func equalFieldType(out interface{}, kind reflect.Kind, key string) bool {
@ -1248,24 +1272,23 @@ var (
)
// Range returns a struct containing the type and a slice of ranges.
func (c *Ctx) Range(size int) (rangeData Range, err error) {
func (c *Ctx) Range(size int) (Range, error) {
var rangeData Range
rangeStr := c.Get(HeaderRange)
if rangeStr == "" || !strings.Contains(rangeStr, "=") {
err = ErrRangeMalformed
return
return rangeData, ErrRangeMalformed
}
data := strings.Split(rangeStr, "=")
if len(data) != 2 {
err = ErrRangeMalformed
return
const expectedDataParts = 2
if len(data) != expectedDataParts {
return rangeData, ErrRangeMalformed
}
rangeData.Type = data[0]
arr := strings.Split(data[1], ",")
for i := 0; i < len(arr); i++ {
item := strings.Split(arr[i], "-")
if len(item) == 1 {
err = ErrRangeMalformed
return
return rangeData, ErrRangeMalformed
}
start, startErr := strconv.Atoi(item[0])
end, endErr := strconv.Atoi(item[1])
@ -1290,11 +1313,10 @@ func (c *Ctx) Range(size int) (rangeData Range, err error) {
})
}
if len(rangeData.Ranges) < 1 {
err = ErrRangeUnsatisfiable
return
return rangeData, ErrRangeUnsatisfiable
}
return
return rangeData, nil
}
// Redirect to the URL derived from the specified path, with specified status.
@ -1330,7 +1352,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
if !segment.IsParam {
_, err := buf.WriteString(segment.Const)
if err != nil {
return "", err
return "", fmt.Errorf("failed to write string: %w", err)
}
continue
}
@ -1341,7 +1363,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
if isSame || isGreedy {
_, err := buf.WriteString(utils.ToString(val))
if err != nil {
return "", err
return "", fmt.Errorf("failed to write string: %w", err)
}
}
}
@ -1373,10 +1395,10 @@ func (c *Ctx) RedirectToRoute(routeName string, params Map, status ...int) error
i := 1
for k, v := range queries {
_, _ = queryText.WriteString(k + "=" + v)
_, _ = queryText.WriteString(k + "=" + v) //nolint:errcheck // This will never fail
if i != len(queries) {
_, _ = queryText.WriteString("&")
_, _ = queryText.WriteString("&") //nolint:errcheck // This will never fail
}
i++
}
@ -1399,7 +1421,6 @@ func (c *Ctx) RedirectBack(fallback string, status ...int) error {
// Render a template with data and sends a text/html response.
// We support the following engines: html, amber, handlebars, mustache, pug
func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
var err error
// Get new buffer from pool
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
@ -1421,7 +1442,7 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
// Render template from Views
if app.config.Views != nil {
if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil {
return err
return fmt.Errorf("failed to render: %w", err)
}
rendered = true
@ -1433,17 +1454,18 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
if !rendered {
// Render raw template using 'name' as filepath if no engine is set
var tmpl *template.Template
if _, err = readContent(buf, name); err != nil {
if _, err := readContent(buf, name); err != nil {
return err
}
// Parse template
if tmpl, err = template.New("").Parse(c.app.getString(buf.Bytes())); err != nil {
return err
tmpl, err := template.New("").Parse(c.app.getString(buf.Bytes()))
if err != nil {
return fmt.Errorf("failed to parse: %w", err)
}
buf.Reset()
// Render template
if err = tmpl.Execute(buf, bind); err != nil {
return err
if err := tmpl.Execute(buf, bind); err != nil {
return fmt.Errorf("failed to execute: %w", err)
}
}
@ -1451,8 +1473,8 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
c.fasthttp.Response.Header.SetContentType(MIMETextHTMLCharsetUTF8)
// Set rendered template to body
c.fasthttp.Response.SetBody(buf.Bytes())
// Return err if exist
return err
return nil
}
func (c *Ctx) renderExtensions(bind interface{}) {
@ -1501,28 +1523,32 @@ func (c *Ctx) Route() *Route {
}
// SaveFile saves any multipart file to disk.
func (c *Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error {
func (*Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error {
return fasthttp.SaveMultipartFile(fileheader, path)
}
// SaveFileToStorage saves any multipart file to an external storage system.
func (c *Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
func (*Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
file, err := fileheader.Open()
if err != nil {
return err
return fmt.Errorf("failed to open: %w", err)
}
content, err := io.ReadAll(file)
if err != nil {
return err
return fmt.Errorf("failed to read: %w", err)
}
return storage.Set(path, content, 0)
if err := storage.Set(path, content, 0); err != nil {
return fmt.Errorf("failed to store: %w", err)
}
return nil
}
// Secure returns whether a secure connection was established.
func (c *Ctx) Secure() bool {
return c.Protocol() == "https"
return c.Protocol() == schemeHTTPS
}
// Send sets the HTTP response body without copying it.
@ -1533,6 +1559,7 @@ func (c *Ctx) Send(body []byte) error {
return nil
}
//nolint:gochecknoglobals // TODO: Do not use global vars here
var (
sendFileOnce sync.Once
sendFileFS *fasthttp.FS
@ -1548,6 +1575,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
// https://github.com/valyala/fasthttp/blob/c7576cc10cabfc9c993317a2d3f8355497bea156/fs.go#L129-L134
sendFileOnce.Do(func() {
const cacheDuration = 10 * time.Second
sendFileFS = &fasthttp.FS{
Root: "",
AllowEmptyRoot: true,
@ -1555,7 +1583,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
AcceptByteRange: true,
Compress: true,
CompressedFileSuffix: c.app.config.CompressedFileSuffix,
CacheDuration: 10 * time.Second,
CacheDuration: cacheDuration,
IndexNames: []string{"index.html"},
PathNotFound: func(ctx *fasthttp.RequestCtx) {
ctx.Response.SetStatusCode(StatusNotFound)
@ -1579,7 +1607,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
var err error
file = filepath.FromSlash(file)
if file, err = filepath.Abs(file); err != nil {
return err
return fmt.Errorf("failed to determine abs file path: %w", err)
}
if hasTrailingSlash {
file += "/"
@ -1644,11 +1672,11 @@ func (c *Ctx) SendStream(stream io.Reader, size ...int) error {
}
// Set sets the response's HTTP header field to the specified key, value.
func (c *Ctx) Set(key string, val string) {
func (c *Ctx) Set(key, val string) {
c.fasthttp.Response.Header.Set(key, val)
}
func (c *Ctx) setCanonical(key string, val string) {
func (c *Ctx) setCanonical(key, val string) {
c.fasthttp.Response.Header.SetCanonical(c.app.getBytes(key), c.app.getBytes(val))
}
@ -1719,6 +1747,7 @@ func (c *Ctx) Write(p []byte) (int, error) {
// Writef appends f & a into response body writer.
func (c *Ctx) Writef(f string, a ...interface{}) (int, error) {
//nolint:wrapcheck // This must not be wrapped
return fmt.Fprintf(c.fasthttp.Response.BodyWriter(), f, a...)
}
@ -1760,8 +1789,9 @@ func (c *Ctx) configDependentPaths() {
// Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed,
// since the first three characters area select a list of routes
c.treePath = c.treePath[0:0]
if len(c.detectionPath) >= 3 {
c.treePath = c.detectionPath[:3]
const maxDetectionPaths = 3
if len(c.detectionPath) >= maxDetectionPaths {
c.treePath = c.detectionPath[:maxDetectionPaths]
}
}
@ -1786,7 +1816,7 @@ func (c *Ctx) IsProxyTrusted() bool {
}
// IsLocalHost will return true if address is a localhost address.
func (c *Ctx) isLocalHost(address string) bool {
func (*Ctx) isLocalHost(address string) bool {
localHosts := []string{"127.0.0.1", "0.0.0.0", "::1"}
for _, h := range localHosts {
if strings.Contains(address, h) {

View File

@ -2,6 +2,7 @@
// 🤖 Github Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package fiber
import (
@ -28,6 +29,7 @@ import (
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
@ -361,8 +363,10 @@ func Test_Ctx_BodyParser(t *testing.T) {
{
var gzipJSON bytes.Buffer
w := gzip.NewWriter(&gzipJSON)
_, _ = w.Write([]byte(`{"name":"john"}`))
_ = w.Close()
_, err := w.Write([]byte(`{"name":"john"}`))
utils.AssertEqual(t, nil, err)
err = w.Close()
utils.AssertEqual(t, nil, err)
c.Request().Header.SetContentType(MIMEApplicationJSON)
c.Request().Header.Set(HeaderContentEncoding, "gzip")
@ -431,9 +435,7 @@ func Test_Ctx_ParamParser(t *testing.T) {
UserID uint `params:"userId"`
RoleID uint `params:"roleId"`
}
var (
d = new(Demo)
)
d := new(Demo)
if err := ctx.ParamsParser(d); err != nil {
t.Fatal(err)
}
@ -519,7 +521,7 @@ func Benchmark_Ctx_BodyParser_JSON(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
_ = c.BodyParser(d)
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
}
utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name)
@ -543,7 +545,7 @@ func Benchmark_Ctx_BodyParser_XML(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
_ = c.BodyParser(d)
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
}
utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name)
@ -567,7 +569,7 @@ func Benchmark_Ctx_BodyParser_Form(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
_ = c.BodyParser(d)
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
}
utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name)
@ -592,7 +594,7 @@ func Benchmark_Ctx_BodyParser_MultipartForm(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
_ = c.BodyParser(d)
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
}
utils.AssertEqual(b, nil, c.BodyParser(d))
utils.AssertEqual(b, "john", d.Name)
@ -879,12 +881,13 @@ func Test_Ctx_FormFile(t *testing.T) {
f, err := fh.Open()
utils.AssertEqual(t, nil, err)
defer func() {
utils.AssertEqual(t, nil, f.Close())
}()
b := new(bytes.Buffer)
_, err = io.Copy(b, f)
utils.AssertEqual(t, nil, err)
f.Close()
utils.AssertEqual(t, "hello world", b.String())
return nil
})
@ -897,8 +900,7 @@ func Test_Ctx_FormFile(t *testing.T) {
_, err = ioWriter.Write([]byte("hello world"))
utils.AssertEqual(t, nil, err)
writer.Close()
utils.AssertEqual(t, nil, writer.Close())
req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set(HeaderContentType, writer.FormDataContentType())
@ -921,10 +923,9 @@ func Test_Ctx_FormValue(t *testing.T) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
utils.AssertEqual(t, nil, writer.Close())
writer.Close()
req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set("Content-Type", fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))
@ -1240,7 +1241,7 @@ func Test_Ctx_IP(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
// default behaviour will return the remote IP from the stack
// default behavior will return the remote IP from the stack
utils.AssertEqual(t, "0.0.0.0", c.IP())
// X-Forwarded-For is set, but it is ignored because proxyHeader is not set
@ -1252,7 +1253,7 @@ func Test_Ctx_IP(t *testing.T) {
func Test_Ctx_IP_ProxyHeader(t *testing.T) {
t.Parallel()
// make sure that the same behaviour exists for different proxy header names
// make sure that the same behavior exists for different proxy header names
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
for _, proxyHeaderName := range proxyHeaderNames {
@ -1286,7 +1287,7 @@ func Test_Ctx_IP_ProxyHeader(t *testing.T) {
func Test_Ctx_IP_ProxyHeader_With_IP_Validation(t *testing.T) {
t.Parallel()
// make sure that the same behaviour exists for different proxy header names
// make sure that the same behavior exists for different proxy header names
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
for _, proxyHeaderName := range proxyHeaderNames {
@ -1625,35 +1626,43 @@ func Test_Ctx_ClientHelloInfo(t *testing.T) {
})
// Test without TLS handler
resp, _ := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ := io.ReadAll(resp.Body)
resp, err := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body)
// Test with TLS Handler
const (
PSSWithSHA256 = 0x0804
VersionTLS13 = 0x0304
pssWithSHA256 = 0x0804
versionTLS13 = 0x0304
)
app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{
ServerName: "example.golang",
SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256},
SupportedVersions: []uint16{VersionTLS13},
SignatureSchemes: []tls.SignatureScheme{pssWithSHA256},
SupportedVersions: []uint16{versionTLS13},
}}
// Test ServerName
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ = io.ReadAll(resp.Body)
resp, err = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("example.golang"), body)
// Test SignatureSchemes
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
body, _ = io.ReadAll(resp.Body)
utils.AssertEqual(t, "["+strconv.Itoa(PSSWithSHA256)+"]", string(body))
resp, err = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "["+strconv.Itoa(pssWithSHA256)+"]", string(body))
// Test SupportedVersions
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
body, _ = io.ReadAll(resp.Body)
utils.AssertEqual(t, "["+strconv.Itoa(VersionTLS13)+"]", string(body))
resp, err = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "["+strconv.Itoa(versionTLS13)+"]", string(body))
}
// go test -run Test_Ctx_InvalidMethod
@ -1688,10 +1697,9 @@ func Test_Ctx_MultipartForm(t *testing.T) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
utils.AssertEqual(t, nil, writer.Close())
writer.Close()
req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set(HeaderContentType, fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
req.Header.Set(HeaderContentLength, strconv.Itoa(len(body.Bytes())))
@ -1706,8 +1714,8 @@ func Benchmark_Ctx_MultipartForm(b *testing.B) {
app := New()
app.Post("/", func(c *Ctx) error {
_, _ = c.MultipartForm()
return nil
_, err := c.MultipartForm()
return err
})
c := &fasthttp.RequestCtx{}
@ -1889,11 +1897,16 @@ func Benchmark_Ctx_AllParams(b *testing.B) {
for n := 0; n < b.N; n++ {
res = c.AllParams()
}
utils.AssertEqual(b, map[string]string{"param1": "john",
"param2": "doe",
"param3": "is",
"param4": "awesome"},
res)
utils.AssertEqual(
b,
map[string]string{
"param1": "john",
"param2": "doe",
"param3": "is",
"param4": "awesome",
},
res,
)
}
// go test -v -run=^$ -bench=Benchmark_Ctx_ParamsParse -benchmem -count=4
@ -1964,31 +1977,31 @@ func Test_Ctx_Protocol(t *testing.T) {
c := app.AcquireCtx(freq)
defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProto, "https, http")
utils.AssertEqual(t, "https", c.Protocol())
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https, http")
utils.AssertEqual(t, "https", c.Protocol())
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol())
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol())
utils.AssertEqual(t, schemeHTTP, c.Protocol())
}
// go test -v -run=^$ -bench=Benchmark_Ctx_Protocol -benchmem -count=4
@ -2002,7 +2015,7 @@ func Benchmark_Ctx_Protocol(b *testing.B) {
for n := 0; n < b.N; n++ {
res = c.Protocol()
}
utils.AssertEqual(b, "http", res)
utils.AssertEqual(b, schemeHTTP, res)
}
// go test -run Test_Ctx_Protocol_TrustedProxy
@ -2012,23 +2025,23 @@ func Test_Ctx_Protocol_TrustedProxy(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol())
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol())
utils.AssertEqual(t, schemeHTTP, c.Protocol())
}
// go test -run Test_Ctx_Protocol_TrustedProxyRange
@ -2038,23 +2051,23 @@ func Test_Ctx_Protocol_TrustedProxyRange(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol())
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol())
utils.AssertEqual(t, schemeHTTP, c.Protocol())
}
// go test -run Test_Ctx_Protocol_UntrustedProxyRange
@ -2064,23 +2077,23 @@ func Test_Ctx_Protocol_UntrustedProxyRange(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "http", c.Protocol())
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol())
utils.AssertEqual(t, schemeHTTP, c.Protocol())
}
// go test -run Test_Ctx_Protocol_UnTrustedProxy
@ -2090,23 +2103,23 @@ func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "http", c.Protocol())
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
utils.AssertEqual(t, schemeHTTP, c.Protocol())
c.Request().Header.Reset()
utils.AssertEqual(t, "http", c.Protocol())
utils.AssertEqual(t, schemeHTTP, c.Protocol())
}
// go test -run Test_Ctx_Query
@ -2224,7 +2237,12 @@ func Test_Ctx_SaveFile(t *testing.T) {
tempFile, err := os.CreateTemp(os.TempDir(), "test-")
utils.AssertEqual(t, nil, err)
defer os.Remove(tempFile.Name())
defer func(file *os.File) {
err := file.Close()
utils.AssertEqual(t, nil, err)
err = os.Remove(file.Name())
utils.AssertEqual(t, nil, err)
}(tempFile)
err = c.SaveFile(fh, tempFile.Name())
utils.AssertEqual(t, nil, err)
@ -2242,7 +2260,7 @@ func Test_Ctx_SaveFile(t *testing.T) {
_, err = ioWriter.Write([]byte("hello world"))
utils.AssertEqual(t, nil, err)
writer.Close()
utils.AssertEqual(t, nil, writer.Close())
req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
@ -2284,7 +2302,7 @@ func Test_Ctx_SaveFileToStorage(t *testing.T) {
_, err = ioWriter.Write([]byte("hello world"))
utils.AssertEqual(t, nil, err)
writer.Close()
utils.AssertEqual(t, nil, writer.Close())
req := httptest.NewRequest(MethodPost, "/test", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
@ -2370,7 +2388,9 @@ func Test_Ctx_Download(t *testing.T) {
f, err := os.Open("./ctx.go")
utils.AssertEqual(t, nil, err)
defer f.Close()
defer func() {
utils.AssertEqual(t, nil, f.Close())
}()
expect, err := io.ReadAll(f)
utils.AssertEqual(t, nil, err)
@ -2389,7 +2409,9 @@ func Test_Ctx_SendFile(t *testing.T) {
// fetch file content
f, err := os.Open("./ctx.go")
utils.AssertEqual(t, nil, err)
defer f.Close()
defer func() {
utils.AssertEqual(t, nil, f.Close())
}()
expectFileContent, err := io.ReadAll(f)
utils.AssertEqual(t, nil, err)
// fetch file info for the not modified test case
@ -2435,7 +2457,7 @@ func Test_Ctx_SendFile_404(t *testing.T) {
return err
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, StatusNotFound, resp.StatusCode)
}
@ -2473,11 +2495,11 @@ func Test_Ctx_SendFile_Immutable(t *testing.T) {
for _, endpoint := range endpointsForTest {
t.Run(endpoint, func(t *testing.T) {
// 1st try
resp, err := app.Test(httptest.NewRequest("GET", endpoint, nil))
resp, err := app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, StatusOK, resp.StatusCode)
// 2nd try
resp, err = app.Test(httptest.NewRequest("GET", endpoint, nil))
resp, err = app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, StatusOK, resp.StatusCode)
})
@ -2495,9 +2517,9 @@ func Test_Ctx_SendFile_RestoreOriginalURL(t *testing.T) {
return err
})
_, err1 := app.Test(httptest.NewRequest("GET", "/?test=true", nil))
_, err1 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
// second request required to confirm with zero allocation
_, err2 := app.Test(httptest.NewRequest("GET", "/?test=true", nil))
_, err2 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, nil, err2)
@ -2893,12 +2915,12 @@ func Test_Ctx_Render(t *testing.T) {
err := c.Render("./.github/testdata/index.tmpl", Map{
"Title": "Hello, World!",
})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite")
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
err = c.Render("./.github/testdata/template-non-exists.html", nil)
@ -2918,12 +2940,12 @@ func Test_Ctx_RenderWithoutLocals(t *testing.T) {
c.Locals("Title", "Hello, World!")
defer app.ReleaseCtx(c)
err := c.Render("./.github/testdata/index.tmpl", Map{})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite")
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1><no value></h1>", string(c.Response().Body()))
}
@ -2937,14 +2959,13 @@ func Test_Ctx_RenderWithLocals(t *testing.T) {
c.Locals("Title", "Hello, World!")
defer app.ReleaseCtx(c)
err := c.Render("./.github/testdata/index.tmpl", Map{})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite")
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
}
func Test_Ctx_RenderWithBind(t *testing.T) {
@ -2959,14 +2980,13 @@ func Test_Ctx_RenderWithBind(t *testing.T) {
utils.AssertEqual(t, nil, err)
defer app.ReleaseCtx(c)
err = c.Render("./.github/testdata/index.tmpl", Map{})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite")
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
}
func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
@ -2982,12 +3002,12 @@ func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
err = c.Render("./.github/testdata/index.tmpl", Map{
"Title": "Hello from Fiber!",
})
utils.AssertEqual(t, nil, err)
buf := bytebufferpool.Get()
_, _ = buf.WriteString("overwrite")
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
defer bytebufferpool.Put(buf)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello from Fiber!</h1>", string(c.Response().Body()))
}
@ -3005,13 +3025,12 @@ func Test_Ctx_RenderWithBindLocals(t *testing.T) {
utils.AssertEqual(t, nil, err)
c.Locals("Summary", "Test")
defer app.ReleaseCtx(c)
err = c.Render("./.github/testdata/template.tmpl", Map{})
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
}
func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
@ -3027,11 +3046,12 @@ func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
c.Locals("Title", "This is a test.")
defer app.ReleaseCtx(c)
err = c.Render("index.tmpl", Map{
"Title": "Hello, World!",
})
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
}
@ -3060,8 +3080,8 @@ func Benchmark_Ctx_RenderWithLocalsAndBinding(b *testing.B) {
for n := 0; n < b.N; n++ {
err = c.Render("template.tmpl", Map{})
}
utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
}
@ -3083,8 +3103,8 @@ func Benchmark_Ctx_RedirectToRoute(b *testing.B) {
"name": "fiber",
})
}
utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, 302, c.Response().StatusCode())
utils.AssertEqual(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation)))
}
@ -3108,8 +3128,8 @@ func Benchmark_Ctx_RedirectToRouteWithQueries(b *testing.B) {
"queries": map[string]string{"a": "a", "b": "b"},
})
}
utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, 302, c.Response().StatusCode())
// analysis of query parameters with url parsing, since a map pass is always randomly ordered
location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation)))
@ -3138,8 +3158,8 @@ func Benchmark_Ctx_RenderLocals(b *testing.B) {
for n := 0; n < b.N; n++ {
err = c.Render("index.tmpl", Map{})
}
utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
}
@ -3164,8 +3184,8 @@ func Benchmark_Ctx_RenderBind(b *testing.B) {
for n := 0; n < b.N; n++ {
err = c.Render("index.tmpl", Map{})
}
utils.AssertEqual(b, nil, err)
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
}
@ -3191,8 +3211,7 @@ func Test_Ctx_RestartRouting(t *testing.T) {
func Test_Ctx_RestartRoutingWithChangedPath(t *testing.T) {
t.Parallel()
app := New()
executedOldHandler := false
executedNewHandler := false
var executedOldHandler, executedNewHandler bool
app.Get("/old", func(c *Ctx) error {
c.Path("/new")
@ -3242,10 +3261,18 @@ type testTemplateEngine struct {
func (t *testTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
if len(layout) == 0 {
return t.templates.ExecuteTemplate(w, name, bind)
if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
return fmt.Errorf("failed to execute template without layout: %w", err)
}
return nil
}
_ = t.templates.ExecuteTemplate(w, name, bind)
return t.templates.ExecuteTemplate(w, layout[0], bind)
if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
return fmt.Errorf("failed to execute template: %w", err)
}
if err := t.templates.ExecuteTemplate(w, layout[0], bind); err != nil {
return fmt.Errorf("failed to execute template with layout: %w", err)
}
return nil
}
func (t *testTemplateEngine) Load() error {
@ -3324,7 +3351,6 @@ func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) {
}
utils.AssertEqual(b, "/user/fiber", location)
utils.AssertEqual(b, nil, err)
}
// go test -run Test_Ctx_Get_Location_From_Route_name
@ -3407,11 +3433,11 @@ func Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param(t *testing.
type errorTemplateEngine struct{}
func (t errorTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
func (errorTemplateEngine) Render(_ io.Writer, _ string, _ interface{}, _ ...string) error {
return errors.New("errorTemplateEngine")
}
func (t errorTemplateEngine) Load() error { return nil }
func (errorTemplateEngine) Load() error { return nil }
// go test -run Test_Ctx_Render_Engine_Error
func Test_Ctx_Render_Engine_Error(t *testing.T) {
@ -3429,7 +3455,10 @@ func Test_Ctx_Render_Go_Template(t *testing.T) {
t.Parallel()
file, err := os.CreateTemp(os.TempDir(), "fiber")
utils.AssertEqual(t, nil, err)
defer os.Remove(file.Name())
defer func() {
err := os.Remove(file.Name())
utils.AssertEqual(t, nil, err)
}()
_, err = file.Write([]byte("template"))
utils.AssertEqual(t, nil, err)
@ -3821,7 +3850,7 @@ func Test_Ctx_QueryParser(t *testing.T) {
}
rq := new(RequiredQuery)
c.Request().URI().SetQueryString("")
utils.AssertEqual(t, "name is empty", c.QueryParser(rq).Error())
utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(rq).Error())
type ArrayQuery struct {
Data []string
@ -3837,7 +3866,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
t.Parallel()
type NonRFCTime time.Time
NonRFCConverter := func(value string) reflect.Value {
nonRFCConverter := func(value string) reflect.Value {
if v, err := time.Parse("2006-01-02", value); err == nil {
return reflect.ValueOf(v)
}
@ -3846,7 +3875,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
nonRFCTime := ParserType{
Customtype: NonRFCTime{},
Converter: NonRFCConverter,
Converter: nonRFCConverter,
}
SetParserDecoder(ParserConfig{
@ -3872,7 +3901,6 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
c.Request().URI().SetQueryString("date=2021-04-10&title=CustomDateTest&Body=October")
utils.AssertEqual(t, nil, c.QueryParser(q))
fmt.Println(q.Date, "q.Date")
utils.AssertEqual(t, "CustomDateTest", q.Title)
date := fmt.Sprintf("%v", q.Date)
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
@ -3907,7 +3935,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("namex=tom&nested.age=10")
q = new(Query1)
utils.AssertEqual(t, "name is empty", c.QueryParser(q).Error())
utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(q).Error())
c.Request().URI().SetQueryString("name=tom&nested.agex=10")
q = new(Query1)
@ -3915,7 +3943,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("name=tom&test.age=10")
q = new(Query1)
utils.AssertEqual(t, "nested is empty", c.QueryParser(q).Error())
utils.AssertEqual(t, "failed to decode: nested is empty", c.QueryParser(q).Error())
type Query2 struct {
Name string `query:"name"`
@ -3933,11 +3961,11 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("nested.agex=10")
q2 = new(Query2)
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error())
utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
c.Request().URI().SetQueryString("nested.agex=10")
q2 = new(Query2)
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error())
utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
type Node struct {
Value int `query:"val,required"`
@ -3951,7 +3979,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
c.Request().URI().SetQueryString("next.val=2")
n = new(Node)
utils.AssertEqual(t, "val is empty", c.QueryParser(n).Error())
utils.AssertEqual(t, "failed to decode: val is empty", c.QueryParser(n).Error())
c.Request().URI().SetQueryString("val=3&next.value=2")
n = new(Node)
@ -4057,7 +4085,7 @@ func Test_Ctx_ReqHeaderParser(t *testing.T) {
}
rh := new(RequiredHeader)
c.Request().Header.Del("name")
utils.AssertEqual(t, "name is empty", c.ReqHeaderParser(rh).Error())
utils.AssertEqual(t, "failed to decode: name is empty", c.ReqHeaderParser(rh).Error())
}
// go test -run Test_Ctx_ReqHeaderParser_WithSetParserDecoder -v
@ -4065,7 +4093,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
t.Parallel()
type NonRFCTime time.Time
NonRFCConverter := func(value string) reflect.Value {
nonRFCConverter := func(value string) reflect.Value {
if v, err := time.Parse("2006-01-02", value); err == nil {
return reflect.ValueOf(v)
}
@ -4074,7 +4102,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
nonRFCTime := ParserType{
Customtype: NonRFCTime{},
Converter: NonRFCConverter,
Converter: nonRFCConverter,
}
SetParserDecoder(ParserConfig{
@ -4103,7 +4131,6 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
c.Request().Header.Add("Body", "October")
utils.AssertEqual(t, nil, c.ReqHeaderParser(r))
fmt.Println(r.Date, "q.Date")
utils.AssertEqual(t, "CustomDateTest", r.Title)
date := fmt.Sprintf("%v", r.Date)
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
@ -4140,7 +4167,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Name")
q = new(Header1)
utils.AssertEqual(t, "Name is empty", c.ReqHeaderParser(q).Error())
utils.AssertEqual(t, "failed to decode: Name is empty", c.ReqHeaderParser(q).Error())
c.Request().Header.Add("Name", "tom")
c.Request().Header.Del("Nested.Age")
@ -4150,7 +4177,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Nested.Agex")
q = new(Header1)
utils.AssertEqual(t, "Nested is empty", c.ReqHeaderParser(q).Error())
utils.AssertEqual(t, "failed to decode: Nested is empty", c.ReqHeaderParser(q).Error())
c.Request().Header.Del("Nested.Agex")
c.Request().Header.Del("Name")
@ -4176,7 +4203,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Nested.Age")
c.Request().Header.Add("Nested.Agex", "10")
h2 = new(Header2)
utils.AssertEqual(t, "Nested.age is empty", c.ReqHeaderParser(h2).Error())
utils.AssertEqual(t, "failed to decode: Nested.age is empty", c.ReqHeaderParser(h2).Error())
type Node struct {
Value int `reqHeader:"Val,required"`
@ -4191,7 +4218,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
c.Request().Header.Del("Val")
n = new(Node)
utils.AssertEqual(t, "Val is empty", c.ReqHeaderParser(n).Error())
utils.AssertEqual(t, "failed to decode: Val is empty", c.ReqHeaderParser(n).Error())
c.Request().Header.Add("Val", "3")
c.Request().Header.Del("Next.Val")
@ -4628,8 +4655,9 @@ func Test_Ctx_RepeatParserWithSameStruct(t *testing.T) {
var gzipJSON bytes.Buffer
w := gzip.NewWriter(&gzipJSON)
_, _ = w.Write([]byte(`{"body_param":"body_param"}`))
_ = w.Close()
_, _ = w.Write([]byte(`{"body_param":"body_param"}`)) //nolint:errcheck // This will never fail
err := w.Close()
utils.AssertEqual(t, nil, err)
c.Request().Header.SetContentType(MIMEApplicationJSON)
c.Request().Header.Set(HeaderContentEncoding, "gzip")
c.Request().SetBody(gzipJSON.Bytes())

View File

@ -1,11 +1,10 @@
package fiber
import (
"encoding/json"
"errors"
"testing"
jerrors "encoding/json"
"github.com/gofiber/fiber/v2/internal/schema"
"github.com/gofiber/fiber/v2/utils"
)
@ -36,42 +35,42 @@ func TestMultiError(t *testing.T) {
func TestInvalidUnmarshalError(t *testing.T) {
t.Parallel()
var e *jerrors.InvalidUnmarshalError
var e *json.InvalidUnmarshalError
ok := errors.As(&InvalidUnmarshalError{}, &e)
utils.AssertEqual(t, true, ok)
}
func TestMarshalerError(t *testing.T) {
t.Parallel()
var e *jerrors.MarshalerError
var e *json.MarshalerError
ok := errors.As(&MarshalerError{}, &e)
utils.AssertEqual(t, true, ok)
}
func TestSyntaxError(t *testing.T) {
t.Parallel()
var e *jerrors.SyntaxError
var e *json.SyntaxError
ok := errors.As(&SyntaxError{}, &e)
utils.AssertEqual(t, true, ok)
}
func TestUnmarshalTypeError(t *testing.T) {
t.Parallel()
var e *jerrors.UnmarshalTypeError
var e *json.UnmarshalTypeError
ok := errors.As(&UnmarshalTypeError{}, &e)
utils.AssertEqual(t, true, ok)
}
func TestUnsupportedTypeError(t *testing.T) {
t.Parallel()
var e *jerrors.UnsupportedTypeError
var e *json.UnsupportedTypeError
ok := errors.As(&UnsupportedTypeError{}, &e)
utils.AssertEqual(t, true, ok)
}
func TestUnsupportedValeError(t *testing.T) {
t.Parallel()
var e *jerrors.UnsupportedValueError
var e *json.UnsupportedValueError
ok := errors.As(&UnsupportedValueError{}, &e)
utils.AssertEqual(t, true, ok)
}

View File

@ -168,7 +168,6 @@ func (grp *Group) Group(prefix string, handlers ...Handler) Router {
}
return newGrp
}
// Route is used to define routes with a common prefix inside the common function.

View File

@ -20,13 +20,13 @@ import (
"unsafe"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
/* #nosec */
// getTlsConfig returns a net listener's tls config
func getTlsConfig(ln net.Listener) *tls.Config {
// getTLSConfig returns a net listener's tls config
func getTLSConfig(ln net.Listener) *tls.Config {
// Get listener type
pointer := reflect.ValueOf(ln)
@ -37,12 +37,16 @@ func getTlsConfig(ln net.Listener) *tls.Config {
// Get private field from value
if field := val.FieldByName("config"); field.Type() != nil {
// Copy value from pointer field (unsafe)
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) // #nosec G103
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) //nolint:gosec // Probably the only way to extract the *tls.Config from a net.Listener. TODO: Verify there really is no easier way without using unsafe.
if newval.Type() != nil {
// Get element from pointer
if elem := newval.Elem(); elem.Type() != nil {
// Cast value to *tls.Config
return elem.Interface().(*tls.Config)
c, ok := elem.Interface().(*tls.Config)
if !ok {
panic(fmt.Errorf("failed to type-assert to *tls.Config"))
}
return c
}
}
}
@ -53,19 +57,21 @@ func getTlsConfig(ln net.Listener) *tls.Config {
}
// readContent opens a named file and read content from it
func readContent(rf io.ReaderFrom, name string) (n int64, err error) {
func readContent(rf io.ReaderFrom, name string) (int64, error) {
// Read file
f, err := os.Open(filepath.Clean(name))
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to open: %w", err)
}
// #nosec G307
defer func() {
if err = f.Close(); err != nil {
log.Printf("Error closing file: %s\n", err)
}
}()
return rf.ReadFrom(f)
if n, err := rf.ReadFrom(f); err != nil {
return n, fmt.Errorf("failed to read: %w", err)
}
return 0, nil
}
// quoteString escape special characters in a given string
@ -78,7 +84,8 @@ func (app *App) quoteString(raw string) string {
}
// Scan stack if other methods match the request
func (app *App) methodExist(ctx *Ctx) (exist bool) {
func (app *App) methodExist(ctx *Ctx) bool {
var exists bool
methods := app.config.RequestMethods
for i := 0; i < len(methods); i++ {
// Skip original method
@ -108,7 +115,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
// No match, next route
if match {
// We matched
exist = true
exists = true
// Add method to Allow header
ctx.Append(HeaderAllow, methods[i])
// Break stack loop
@ -116,7 +123,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
}
}
}
return
return exists
}
// uniqueRouteStack drop all not unique routes from the slice
@ -146,7 +153,7 @@ func defaultString(value string, defaultValue []string) string {
const normalizedHeaderETag = "Etag"
// Generate and set ETag header to response
func setETag(c *Ctx, weak bool) {
func setETag(c *Ctx, weak bool) { //nolint: revive // Accepting a bool param is fine here
// Don't generate ETags for invalid responses
if c.fasthttp.Response.StatusCode() != StatusOK {
return
@ -160,7 +167,8 @@ func setETag(c *Ctx, weak bool) {
clientEtag := c.Get(HeaderIfNoneMatch)
// Generate ETag for response
crc32q := crc32.MakeTable(0xD5828281)
const pol = 0xD5828281
crc32q := crc32.MakeTable(pol)
etag := fmt.Sprintf("\"%d-%v\"", len(body), crc32.Checksum(body, crc32q))
// Enable weak tag
@ -173,7 +181,9 @@ func setETag(c *Ctx, weak bool) {
// Check if server's ETag is weak
if clientEtag[2:] == etag || clientEtag[2:] == etag[2:] {
// W/1 == 1 || W/1 == W/1
_ = c.SendStatus(StatusNotModified)
if err := c.SendStatus(StatusNotModified); err != nil {
log.Printf("setETag: failed to SendStatus: %v\n", err)
}
c.fasthttp.ResetBody()
return
}
@ -183,7 +193,9 @@ func setETag(c *Ctx, weak bool) {
}
if strings.Contains(clientEtag, etag) {
// 1 == 1
_ = c.SendStatus(StatusNotModified)
if err := c.SendStatus(StatusNotModified); err != nil {
log.Printf("setETag: failed to SendStatus: %v\n", err)
}
c.fasthttp.ResetBody()
return
}
@ -239,7 +251,7 @@ func getOffer(header string, offers ...string) string {
return ""
}
func matchEtag(s string, etag string) bool {
func matchEtag(s, etag string) bool {
if s == etag || s == "W/"+etag || "W/"+s == etag {
return true
}
@ -254,12 +266,12 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
// https://github.com/jshttp/fresh/blob/10e0471669dbbfbfd8de65bc6efac2ddd0bfa057/index.js#L110
for i := range noneMatchBytes {
switch noneMatchBytes[i] {
case 0x20:
case 0x20: //nolint:gomnd // This is a space (" ")
if start == end {
start = i + 1
end = i + 1
}
case 0x2c:
case 0x2c: //nolint:gomnd // This is a comma (",")
if matchEtag(app.getString(noneMatchBytes[start:end]), etag) {
return false
}
@ -273,7 +285,7 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
return !matchEtag(app.getString(noneMatchBytes[start:end]), etag)
}
func parseAddr(raw string) (host, port string) {
func parseAddr(raw string) (string, string) { //nolint:revive // Returns (host, port)
if i := strings.LastIndex(raw, ":"); i != -1 {
return raw[:i], raw[i+1:]
}
@ -313,21 +325,21 @@ type testConn struct {
w bytes.Buffer
}
func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) }
func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) }
func (c *testConn) Close() error { return nil }
func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) } //nolint:wrapcheck // This must not be wrapped
func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) } //nolint:wrapcheck // This must not be wrapped
func (*testConn) Close() error { return nil }
func (c *testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
func (c *testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
func (c *testConn) SetDeadline(_ time.Time) error { return nil }
func (c *testConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *testConn) SetWriteDeadline(_ time.Time) error { return nil }
func (*testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
func (*testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
func (*testConn) SetDeadline(_ time.Time) error { return nil }
func (*testConn) SetReadDeadline(_ time.Time) error { return nil }
func (*testConn) SetWriteDeadline(_ time.Time) error { return nil }
var getStringImmutable = func(b []byte) string {
func getStringImmutable(b []byte) string {
return string(b)
}
var getBytesImmutable = func(s string) (b []byte) {
func getBytesImmutable(s string) []byte {
return []byte(s)
}
@ -335,6 +347,7 @@ var getBytesImmutable = func(s string) (b []byte) {
func (app *App) methodInt(s string) int {
// For better performance
if len(app.configured.RequestMethods) == 0 {
//nolint:gomnd // TODO: Use iota instead
switch s {
case MethodGet:
return 0
@ -391,8 +404,7 @@ func IsMethodIdempotent(m string) bool {
}
switch m {
case MethodPut,
MethodDelete:
case MethodPut, MethodDelete:
return true
default:
return false
@ -714,7 +726,7 @@ const (
ConstraintBool = "bool"
ConstraintFloat = "float"
ConstraintAlpha = "alpha"
ConstraintGuid = "guid"
ConstraintGuid = "guid" //nolint:revive,stylecheck // TODO: Rename to "ConstraintGUID" in v3
ConstraintMinLen = "minLen"
ConstraintMaxLen = "maxLen"
ConstraintLen = "len"

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)

View File

@ -1,14 +1,20 @@
package fiber
import (
"log"
)
// OnRouteHandler Handlers define a function to create hooks for Fiber.
type OnRouteHandler = func(Route) error
type OnNameHandler = OnRouteHandler
type OnGroupHandler = func(Group) error
type OnGroupNameHandler = OnGroupHandler
type OnListenHandler = func() error
type OnShutdownHandler = OnListenHandler
type OnForkHandler = func(int) error
type OnMountHandler = func(*App) error
type (
OnRouteHandler = func(Route) error
OnNameHandler = OnRouteHandler
OnGroupHandler = func(Group) error
OnGroupNameHandler = OnGroupHandler
OnListenHandler = func() error
OnShutdownHandler = OnListenHandler
OnForkHandler = func(int) error
OnMountHandler = func(*App) error
)
// Hooks is a struct to use it with App.
type Hooks struct {
@ -180,13 +186,17 @@ func (h *Hooks) executeOnListenHooks() error {
func (h *Hooks) executeOnShutdownHooks() {
for _, v := range h.onShutdown {
_ = v()
if err := v(); err != nil {
log.Printf("failed to call shutdown hook: %v\n", err)
}
}
}
func (h *Hooks) executeOnForkHooks(pid int) {
for _, v := range h.onFork {
_ = v(pid)
if err := v(pid); err != nil {
log.Printf("failed to call fork hook: %v\n", err)
}
}
}

View File

@ -7,10 +7,11 @@ import (
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool"
)
var testSimpleHandler = func(c *Ctx) error {
func testSimpleHandler(c *Ctx) error {
return c.SendString("simple")
}

View File

@ -10,7 +10,7 @@ var defaultPool = sync.Pool{
// AcquireDict acquire new dict.
func AcquireDict() *Dict {
return defaultPool.Get().(*Dict) // nolint:forcetypeassert
return defaultPool.Get().(*Dict)
}
// ReleaseDict release dict.

View File

@ -366,7 +366,7 @@ func HostDev(combineWith ...string) string {
// getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running
// sysctl commands (see DoSysctrl).
func getSysctrlEnv(env []string) []string {
foundLC := false
var foundLC bool
for i, line := range env {
if strings.HasPrefix(line, "LC_ALL") {
env[i] = "LC_ALL=C"

View File

@ -6,8 +6,7 @@ import (
"github.com/gofiber/fiber/v2/internal/gopsutil/common"
)
//lint:ignore U1000 we need this elsewhere
var invoke common.Invoker = common.Invoke{} //nolint:all
var invoke common.Invoker = common.Invoke{} //nolint:unused // We use this only for some OS'es
// Memory usage statistics. Total, Available and Used contain numbers of bytes
// for human consumption.

View File

@ -86,7 +86,6 @@ func SwapMemory() (*SwapMemoryStat, error) {
}
// Constants from vm/vm_param.h
// nolint: golint
const (
XSWDEV_VERSION11 = 1
XSWDEV_VERSION = 2

View File

@ -57,10 +57,12 @@ func fillFromMeminfoWithContext(ctx context.Context) (*VirtualMemoryStat, *Virtu
lines, _ := common.ReadLines(filename)
// flag if MemAvailable is in /proc/meminfo (kernel 3.14+)
memavail := false
activeFile := false // "Active(file)" not available: 2.6.28 / Dec 2008
inactiveFile := false // "Inactive(file)" not available: 2.6.28 / Dec 2008
sReclaimable := false // "SReclaimable:" not available: 2.6.19 / Nov 2006
var (
memavail bool
activeFile bool // "Active(file)" not available: 2.6.28 / Dec 2008
inactiveFile bool // "Inactive(file)" not available: 2.6.28 / Dec 2008
sReclaimable bool // "SReclaimable:" not available: 2.6.19 / Nov 2006
)
ret := &VirtualMemoryStat{}
retEx := &VirtualMemoryExStat{}

View File

@ -168,7 +168,7 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) {
if err != nil {
return
}
comma := false
var comma bool
for i := uint32(0); i < sz; i++ {
if comma {
err = dst.WriteByte(',')

View File

@ -163,7 +163,7 @@ func (c *cache) createField(field reflect.StructField, parentAlias string) *fiel
}
// Check if the type is supported and don't cache it if not.
// First let's get the basic type.
isSlice, isStruct := false, false
var isSlice, isStruct bool
ft := field.Type
m := isTextUnmarshaler(reflect.Zero(ft))
if ft.Kind() == reflect.Ptr {

View File

@ -149,13 +149,13 @@ func Benchmark_Storage_Memory(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
for _, key := range keys {
d.Set(key, value, ttl)
_ = d.Set(key, value, ttl)
}
for _, key := range keys {
_, _ = d.Get(key)
}
for _, key := range keys {
d.Delete(key)
_ = d.Delete(key)
}
}
})

View File

@ -156,7 +156,6 @@ func (e *Engine) Load() error {
name = strings.TrimSuffix(name, e.extension)
// name = strings.Replace(name, e.extension, "", -1)
// Read the file
// #gosec G304
buf, err := utils.ReadFile(path, e.fileSystem)
if err != nil {
return err

View File

@ -21,7 +21,6 @@ func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
return walk(fs, root, info, walkFn)
}
// #nosec G304
// ReadFile returns the raw content of a file
func ReadFile(path string, fs http.FileSystem) ([]byte, error) {
if fs != nil {

View File

@ -9,6 +9,7 @@ import (
"crypto/x509"
"errors"
"fmt"
"log"
"net"
"os"
"path/filepath"
@ -31,7 +32,7 @@ func (app *App) Listener(ln net.Listener) error {
// Print startup message
if !app.config.DisableStartupMessage {
app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "")
app.startupMessage(ln.Addr().String(), getTLSConfig(ln) != nil, "")
}
// Print routes
@ -41,7 +42,7 @@ func (app *App) Listener(ln net.Listener) error {
// Prefork is not supported for custom listeners
if app.config.Prefork {
fmt.Println("[Warning] Prefork isn't supported for custom listeners.")
log.Printf("[Warning] Prefork isn't supported for custom listeners.\n")
}
// Start listening
@ -61,7 +62,7 @@ func (app *App) Listen(addr string) error {
// Setup listener
ln, err := net.Listen(app.config.Network, addr)
if err != nil {
return err
return fmt.Errorf("failed to listen: %w", err)
}
// prepare the server for the start
@ -94,7 +95,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
// Set TLS config with handler
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
}
tlsHandler := &TLSHandler{}
@ -115,7 +116,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
ln, err := net.Listen(app.config.Network, addr)
ln = tls.NewListener(ln, config)
if err != nil {
return err
return fmt.Errorf("failed to listen: %w", err)
}
// prepare the server for the start
@ -150,12 +151,12 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
}
clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile))
if err != nil {
return err
return fmt.Errorf("failed to read file: %w", err)
}
clientCertPool := x509.NewCertPool()
clientCertPool.AppendCertsFromPEM(clientCACert)
@ -179,7 +180,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
// Setup listener
ln, err := tls.Listen(app.config.Network, addr, config)
if err != nil {
return err
return fmt.Errorf("failed to listen: %w", err)
}
// prepare the server for the start
@ -203,7 +204,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
}
// startupMessage prepares the startup message with the handler number, port, address and other information
func (app *App) startupMessage(addr string, tls bool, pids string) {
func (app *App) startupMessage(addr string, tls bool, pids string) { //nolint: revive // Accepting a bool param is fine here
// ignore child processes
if IsChild() {
return
@ -227,7 +228,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
}
center := func(s string, width int) string {
pad := strconv.Itoa((width - len(s)) / 2)
const padDiv = 2
pad := strconv.Itoa((width - len(s)) / padDiv)
str := fmt.Sprintf("%"+pad+"s", " ")
str += s
str += fmt.Sprintf("%"+pad+"s", " ")
@ -238,7 +240,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
}
centerValue := func(s string, width int) string {
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / 2)
const padDiv = 2
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / padDiv)
str := fmt.Sprintf("%"+pad+"s", " ")
str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black)
str += fmt.Sprintf("%"+pad+"s", " ")
@ -249,13 +252,13 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
return str
}
pad := func(s string, width int) (str string) {
pad := func(s string, width int) string {
toAdd := width - len(s)
str += s
str := s
for i := 0; i < toAdd; i++ {
str += " "
}
return
return str
}
host, port := parseAddr(addr)
@ -267,9 +270,9 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
}
}
scheme := "http"
scheme := schemeHTTP
if tls {
scheme = "https"
scheme = schemeHTTPS
}
isPrefork := "Disabled"
@ -282,19 +285,18 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
procs = "1"
}
const lineLen = 49
mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n"
if app.config.AppName != "" {
mainLogo += " │ " + centerValue(app.config.AppName, 49) + " │\n"
mainLogo += " │ " + centerValue(app.config.AppName, lineLen) + " │\n"
}
mainLogo += " │ " + centerValue("Fiber v"+Version, 49) + " │\n"
mainLogo += " │ " + centerValue("Fiber v"+Version, lineLen) + " │\n"
if host == "0.0.0.0" {
mainLogo +=
" │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), 49) + " │\n" +
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), 49) + " │\n"
mainLogo += " │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), lineLen) + " │\n" +
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), lineLen) + " │\n"
} else {
mainLogo +=
" │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), 49) + " │\n"
mainLogo += " │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), lineLen) + " │\n"
}
mainLogo += fmt.Sprintf(
@ -303,8 +305,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
" │ Prefork .%s PID ....%s │\n"+
" └───────────────────────────────────────────────────┘"+
colors.Reset,
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12),
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12), //nolint:gomnd // Using random padding lengths is fine here
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14), //nolint:gomnd // Using random padding lengths is fine here
)
var childPidsLogo string
@ -329,19 +331,21 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
thisLine := "Child PIDs ... "
var itemsOnThisLine []string
const maxLineLen = 49
addLine := func() {
lines = append(lines,
fmt.Sprintf(
newLine,
colors.Black,
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), 49-len(thisLine)),
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), maxLineLen-len(thisLine)),
colors.Black,
),
)
}
for _, pid := range pidSlice {
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > 49 {
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > maxLineLen {
addLine()
thisLine = ""
itemsOnThisLine = []string{pid}
@ -415,7 +419,7 @@ func (app *App) printRoutesMessage() {
var routes []RouteMessage
for _, routeStack := range app.stack {
for _, route := range routeStack {
var newRoute = RouteMessage{}
var newRoute RouteMessage
newRoute.name = route.Name
newRoute.method = route.Method
newRoute.path = route.Path
@ -443,5 +447,5 @@ func (app *App) printRoutesMessage() {
_, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers)
}
_ = w.Flush()
_ = w.Flush() //nolint:errcheck // It is fine to ignore the error here
}

View File

@ -7,7 +7,6 @@ package fiber
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"log"
"os"
@ -17,6 +16,7 @@ import (
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp/fasthttputil"
)
@ -125,8 +125,10 @@ func Test_App_Listener_TLS_Listener(t *testing.T) {
if err != nil {
utils.AssertEqual(t, nil, err)
}
//nolint:gosec // We're in a test so using old ciphers is fine
config := &tls.Config{Certificates: []tls.Certificate{cer}}
//nolint:gosec // We're in a test so listening on all interfaces is fine
ln, err := tls.Listen(NetworkTCP4, ":0", config)
utils.AssertEqual(t, nil, err)
@ -182,7 +184,6 @@ func Test_App_Master_Process_Show_Startup_Message(t *testing.T) {
New(Config{Prefork: true}).
startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
})
fmt.Println(startupMessage)
utils.AssertEqual(t, true, strings.Contains(startupMessage, "https://127.0.0.1:3000"))
utils.AssertEqual(t, true, strings.Contains(startupMessage, "(bound on host 0.0.0.0 and port 3000)"))
utils.AssertEqual(t, true, strings.Contains(startupMessage, "Child PIDs"))
@ -196,7 +197,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) {
startupMessage := captureOutput(func() {
app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
})
fmt.Println(startupMessage)
utils.AssertEqual(t, "Test App v1.0.1", app.Config().AppName)
utils.AssertEqual(t, true, strings.Contains(startupMessage, app.Config().AppName))
}
@ -208,7 +208,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.
startupMessage := captureOutput(func() {
app.startupMessage(":3000", false, "")
})
fmt.Println(startupMessage)
utils.AssertEqual(t, true, strings.Contains(startupMessage, "│ Serveur de vérification des données │"))
}
@ -219,8 +218,7 @@ func Test_App_print_Route(t *testing.T) {
printRoutesMessage := captureOutput(func() {
app.printRoutesMessage()
})
fmt.Println(printRoutesMessage)
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "routeName"))
@ -240,11 +238,11 @@ func Test_App_print_Route_with_group(t *testing.T) {
app.printRoutesMessage()
})
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "POST"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodPost))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "PUT"))
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber/*"))

View File

@ -1,15 +1,15 @@
package basicauth
import (
"encoding/base64"
"fmt"
"io"
"net/http/httptest"
"testing"
b64 "encoding/base64"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -23,7 +23,7 @@ func Test_BasicAuth_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -39,6 +39,7 @@ func Test_Middleware_BasicAuth(t *testing.T) {
},
}))
//nolint:forcetypeassert,errcheck // TODO: Do not force-type assert
app.Get("/testauth", func(c *fiber.Ctx) error {
username := c.Locals("username").(string)
password := c.Locals("password").(string)
@ -74,9 +75,9 @@ func Test_Middleware_BasicAuth(t *testing.T) {
for _, tt := range tests {
// Base64 encode credentials for http auth header
creds := b64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
creds := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
req := httptest.NewRequest("GET", "/testauth", nil)
req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil)
req.Header.Add("Authorization", "Basic "+creds)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -108,7 +109,7 @@ func Benchmark_Middleware_BasicAuth(b *testing.B) {
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe

View File

@ -53,6 +53,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
Users: map[string]string{},

View File

@ -34,6 +34,7 @@ const (
noStore = "no-store"
)
//nolint:gochecknoglobals // TODO: Do not use a global var here
var ignoreHeaders = map[string]interface{}{
"Connection": nil,
"Keep-Alive": nil,
@ -43,8 +44,8 @@ var ignoreHeaders = map[string]interface{}{
"Trailers": nil,
"Transfer-Encoding": nil,
"Upgrade": nil,
"Content-Type": nil, // already stored explicitely by the cache manager
"Content-Encoding": nil, // already stored explicitely by the cache manager
"Content-Type": nil, // already stored explicitly by the cache manager
"Content-Encoding": nil, // already stored explicitly by the cache manager
}
// New creates a new middleware handler
@ -69,7 +70,7 @@ func New(config ...Config) fiber.Handler {
// Create indexed heap for tracking expirations ( see heap.go )
heap := &indexedHeap{}
// count stored bytes (sizes of response bodies)
var storedBytes uint = 0
var storedBytes uint
// Update timestamp in the configured interval
go func() {
@ -81,10 +82,10 @@ func New(config ...Config) fiber.Handler {
// Delete key from both manager and storage
deleteKey := func(dkey string) {
manager.delete(dkey)
manager.del(dkey)
// External storage saves body data with different key
if cfg.Storage != nil {
manager.delete(dkey + "_body")
manager.del(dkey + "_body")
}
}
@ -205,7 +206,7 @@ func New(config ...Config) fiber.Handler {
if cfg.StoreResponseHeaders {
e.headers = make(map[string][]byte)
c.Response().Header.VisitAll(
func(key []byte, value []byte) {
func(key, value []byte) {
// create real copy
keyS := string(key)
if _, ok := ignoreHeaders[keyS]; !ok {

View File

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"os"
"strconv"
@ -18,6 +17,7 @@ import (
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/middleware/etag"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -35,10 +35,10 @@ func Test_Cache_CacheControl(t *testing.T) {
return c.SendString("Hello, World!")
})
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
}
@ -53,7 +53,7 @@ func Test_Cache_Expired(t *testing.T) {
return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano()))
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
@ -61,7 +61,7 @@ func Test_Cache_Expired(t *testing.T) {
// Sleep until the cache is expired
time.Sleep(3 * time.Second)
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err)
@ -71,7 +71,7 @@ func Test_Cache_Expired(t *testing.T) {
}
// Next response should be also cached
respCachedNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
utils.AssertEqual(t, nil, err)
@ -92,11 +92,11 @@ func Test_Cache(t *testing.T) {
return c.SendString(now)
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest("GET", "/", nil)
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
@ -120,31 +120,31 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
})
// Request id = 1
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("1"), body)
// Response cached, entry id = 1
// Request id = 2 without Cache-Control: no-cache
cachedReq := httptest.NewRequest("GET", "/?id=2", nil)
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
cachedResp, err := app.Test(cachedReq)
defer cachedResp.Body.Close()
cachedBody, _ := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("1"), cachedBody)
// Response not cached, returns cached response, entry id = 1
// Request id = 2 with Cache-Control: no-cache
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil)
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheResp, err := app.Test(noCacheReq)
defer noCacheResp.Body.Close()
noCacheBody, _ := io.ReadAll(noCacheResp.Body)
utils.AssertEqual(t, nil, err)
noCacheBody, err := io.ReadAll(noCacheResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), noCacheBody)
@ -152,21 +152,21 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
/* Check Test_Cache_WithETagAndNoCacheRequestDirective */
// Request id = 2 with Cache-Control: no-cache again
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil)
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheResp1, err := app.Test(noCacheReq1)
defer noCacheResp1.Body.Close()
noCacheBody1, _ := io.ReadAll(noCacheResp1.Body)
utils.AssertEqual(t, nil, err)
noCacheBody1, err := io.ReadAll(noCacheResp1.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), noCacheBody1)
// Response cached, returns updated response, entry = 2
// Request id = 1 without Cache-Control: no-cache
cachedReq1 := httptest.NewRequest("GET", "/", nil)
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp1, err := app.Test(cachedReq1)
defer cachedResp1.Body.Close()
cachedBody1, _ := io.ReadAll(cachedResp1.Body)
utils.AssertEqual(t, nil, err)
cachedBody1, err := io.ReadAll(cachedResp1.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), cachedBody1)
@ -188,7 +188,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
})
// Request id = 1
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
@ -199,7 +199,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
etagToken := resp.Header.Get("Etag")
// Request id = 2 with ETag but without Cache-Control: no-cache
cachedReq := httptest.NewRequest("GET", "/?id=2", nil)
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
@ -208,7 +208,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
// Response not cached, returns cached response, entry id = 1, status not modified
// Request id = 2 with ETag and Cache-Control: no-cache
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil)
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
noCacheResp, err := app.Test(noCacheReq)
@ -221,7 +221,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
etagToken = noCacheResp.Header.Get("Etag")
// Request id = 2 with ETag and Cache-Control: no-cache again
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil)
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
noCacheResp1, err := app.Test(noCacheReq1)
@ -231,7 +231,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
// Response cached, returns updated response, entry id = 2, status not modified
// Request id = 1 without ETag and Cache-Control: no-cache
cachedReq1 := httptest.NewRequest("GET", "/", nil)
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp1, err := app.Test(cachedReq1)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
@ -251,11 +251,11 @@ func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
})
// Request id = 2
noStoreReq := httptest.NewRequest("GET", "/?id=2", nil)
noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
noStoreResp, err := app.Test(noStoreReq)
defer noStoreResp.Body.Close()
noStoreBody, _ := io.ReadAll(noStoreResp.Body)
utils.AssertEqual(t, nil, err)
noStoreBody, err := io.ReadAll(noStoreResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("2"), noStoreBody)
// Response not cached, returns updated response
@ -278,11 +278,11 @@ func Test_Cache_WithSeveralRequests(t *testing.T) {
for runs := 0; runs < 10; runs++ {
for i := 0; i < 10; i++ {
func(id int) {
rsp, err := app.Test(httptest.NewRequest(http.MethodGet, fmt.Sprintf("/%d", id), nil))
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), nil))
utils.AssertEqual(t, nil, err)
defer func(Body io.ReadCloser) {
err := Body.Close()
defer func(body io.ReadCloser) {
err := body.Close()
utils.AssertEqual(t, nil, err)
}(rsp.Body)
@ -311,11 +311,11 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
return c.SendString(now)
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest("GET", "/", nil)
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
@ -342,25 +342,25 @@ func Test_Cache_Get(t *testing.T) {
return c.SendString(c.Query("cache"))
})
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "12345", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
@ -384,25 +384,25 @@ func Test_Cache_Post(t *testing.T) {
return c.SendString(c.Query("cache"))
})
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
@ -420,14 +420,14 @@ func Test_Cache_NothingToCache(t *testing.T) {
return c.SendString(time.Now().String())
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
time.Sleep(500 * time.Millisecond)
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err)
@ -457,22 +457,22 @@ func Test_Cache_CustomNext(t *testing.T) {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Equal(body, bodyCached))
utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "")
_, err = app.Test(httptest.NewRequest("GET", "/error", nil))
_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err)
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil))
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "")
}
@ -491,7 +491,7 @@ func Test_CustomKey(t *testing.T) {
return c.SendString("hi")
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
_, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called)
@ -505,7 +505,9 @@ func Test_CustomExpiration(t *testing.T) {
var newCacheTime int
app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration {
called = true
newCacheTime, _ = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
var err error
newCacheTime, err = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
utils.AssertEqual(t, nil, err)
return time.Second * time.Duration(newCacheTime)
}}))
@ -515,7 +517,7 @@ func Test_CustomExpiration(t *testing.T) {
return c.SendString(now)
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called)
utils.AssertEqual(t, 1, newCacheTime)
@ -523,7 +525,7 @@ func Test_CustomExpiration(t *testing.T) {
// Sleep until the cache is expired
time.Sleep(1 * time.Second)
cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil))
cachedResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
@ -536,7 +538,7 @@ func Test_CustomExpiration(t *testing.T) {
}
// Next response should be cached
cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
cachedRespNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
utils.AssertEqual(t, nil, err)
@ -559,12 +561,12 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) {
return c.SendString("hi")
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
req = httptest.NewRequest("GET", "/", nil)
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
@ -594,19 +596,19 @@ func Test_CacheHeader(t *testing.T) {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest("GET", "/", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache"))
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil))
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
}
@ -622,12 +624,12 @@ func Test_Cache_WithHead(t *testing.T) {
return c.SendString(now)
})
req := httptest.NewRequest("HEAD", "/", nil)
req := httptest.NewRequest(fiber.MethodHead, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
cachedReq := httptest.NewRequest("HEAD", "/", nil)
cachedReq := httptest.NewRequest(fiber.MethodHead, "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
@ -649,28 +651,28 @@ func Test_Cache_WithHeadThenGet(t *testing.T) {
return c.SendString(c.Query("cache"))
})
headResp, err := app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil))
headResp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
headBody, err := io.ReadAll(headResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(headBody))
utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache"))
headResp, err = app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil))
headResp, err = app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
headBody, err = io.ReadAll(headResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(headBody))
utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache"))
getResp, err := app.Test(httptest.NewRequest("GET", "/?cache=123", nil))
getResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
getBody, err := io.ReadAll(getResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(getBody))
utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache"))
getResp, err = app.Test(httptest.NewRequest("GET", "/?cache=123", nil))
getResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
getBody, err = io.ReadAll(getResp.Body)
utils.AssertEqual(t, nil, err)
@ -691,7 +693,7 @@ func Test_CustomCacheHeader(t *testing.T) {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status"))
}
@ -702,7 +704,7 @@ func Test_CustomCacheHeader(t *testing.T) {
func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration {
i := 0
return func(c1 *fiber.Ctx, c2 *Config) time.Duration {
i += 1
i++
return time.Hour * time.Duration(i)
}
}
@ -738,7 +740,7 @@ func Test_Cache_MaxBytesOrder(t *testing.T) {
}
for idx, tcase := range cases {
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil))
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
}
@ -756,7 +758,8 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
app.Get("/*", func(c *fiber.Ctx) error {
path := c.Context().URI().LastPathSegment()
size, _ := strconv.Atoi(string(path))
size, err := strconv.Atoi(string(path))
utils.AssertEqual(t, nil, err)
return c.Send(make([]byte, size))
})
@ -772,7 +775,7 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
}
for idx, tcase := range cases {
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil))
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
}
@ -785,14 +788,14 @@ func Benchmark_Cache(b *testing.B) {
app.Use(New())
app.Get("/demo", func(c *fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md")
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusTeapot).Send(data)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo")
b.ReportAllocs()
@ -815,14 +818,14 @@ func Benchmark_Cache_Storage(b *testing.B) {
}))
app.Get("/demo", func(c *fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md")
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusTeapot).Send(data)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo")
b.ReportAllocs()
@ -850,7 +853,7 @@ func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo")
b.ReportAllocs()
@ -882,7 +885,7 @@ func Benchmark_Cache_MaxSize(b *testing.B) {
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
b.ReportAllocs()
b.ResetTimer()

View File

@ -1,7 +1,7 @@
package cache
import (
"fmt"
"log"
"time"
"github.com/gofiber/fiber/v2"
@ -49,10 +49,10 @@ type Config struct {
// Default: an in memory store for this process only
Storage fiber.Storage
// Deprecated, use Storage instead
// Deprecated: Use Storage instead
Store fiber.Storage
// Deprecated, use KeyGenerator instead
// Deprecated: Use KeyGenerator instead
Key func(*fiber.Ctx) string
// allows you to store additional headers generated by next middlewares & handler
@ -75,6 +75,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
@ -102,11 +104,11 @@ func configDefault(config ...Config) Config {
// Set default values
if cfg.Store != nil {
fmt.Println("[CACHE] Store is deprecated, please use Storage")
log.Printf("[CACHE] Store is deprecated, please use Storage\n")
cfg.Storage = cfg.Store
}
if cfg.Key != nil {
fmt.Println("[CACHE] Key is deprecated, please use KeyGenerator")
log.Printf("[CACHE] Key is deprecated, please use KeyGenerator\n")
cfg.KeyGenerator = cfg.Key
}
if cfg.Next == nil {

View File

@ -41,7 +41,7 @@ func (h indexedHeap) Swap(i, j int) {
}
func (h *indexedHeap) Push(x interface{}) {
h.pushInternal(x.(heapEntry))
h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert // Forced type assertion required to implement the heap.Interface interface
}
func (h *indexedHeap) Pop() interface{} {
@ -65,7 +65,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
idx = h.entries[:n+1][n].idx
} else {
idx = h.maxidx
h.maxidx += 1
h.maxidx++
h.indices = append(h.indices, idx)
}
// Push manually to avoid allocation
@ -77,7 +77,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
}
func (h *indexedHeap) removeInternal(realIdx int) (string, uint) {
x := heap.Remove(h, realIdx).(heapEntry)
x := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
return x.key, x.bytes
}

View File

@ -51,7 +51,7 @@ func newManager(storage fiber.Storage) *manager {
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
}
// release and reset *entry to sync.Pool
@ -69,38 +69,47 @@ func (m *manager) release(e *item) {
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
func (m *manager) get(key string) *item {
var it *item
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
raw, err := m.storage.Get(key)
if err != nil {
return it
}
if raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
return it
}
}
return
return it
}
if it, _ = m.memory.Get(key).(*item); it == nil {
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire()
return it
}
return
return it
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
func (m *manager) getRaw(key string) []byte {
var raw []byte
if m.storage != nil {
raw, _ = m.storage.Get(key)
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Handle error here
} else {
raw, _ = m.memory.Get(key).([]byte)
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Handle error here
}
return
return raw
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
}
// we can release data because it's serialized to database
m.release(it)
} else {
m.memory.Set(key, it, exp)
}
@ -109,16 +118,16 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
func (m *manager) del(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Handle error here
} else {
m.memory.Delete(key)
}

View File

@ -2,6 +2,7 @@ package compress
import (
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)

View File

@ -12,8 +12,10 @@ import (
"github.com/gofiber/fiber/v2/utils"
)
//nolint:gochecknoglobals // Using a global var is fine here
var filedata []byte
//nolint:gochecknoinits // init() is used to populate a global var from a README file
func init() {
dat, err := os.ReadFile("../../.github/README.md")
if err != nil {
@ -34,7 +36,7 @@ func Test_Compress_Gzip(t *testing.T) {
return c.Send(filedata)
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req)
@ -64,7 +66,7 @@ func Test_Compress_Different_Level(t *testing.T) {
return c.Send(filedata)
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req)
@ -90,7 +92,7 @@ func Test_Compress_Deflate(t *testing.T) {
return c.Send(filedata)
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "deflate")
resp, err := app.Test(req)
@ -114,7 +116,7 @@ func Test_Compress_Brotli(t *testing.T) {
return c.Send(filedata)
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "br")
resp, err := app.Test(req, 10000)
@ -138,7 +140,7 @@ func Test_Compress_Disabled(t *testing.T) {
return c.Send(filedata)
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "br")
resp, err := app.Test(req)
@ -162,7 +164,7 @@ func Test_Compress_Next_Error(t *testing.T) {
return errors.New("next error")
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req)
@ -185,7 +187,7 @@ func Test_Compress_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

View File

@ -33,6 +33,8 @@ const (
)
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
Level: LevelDefault,

View File

@ -1,7 +1,6 @@
package cors
import (
"net/http"
"strconv"
"strings"
@ -54,6 +53,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
AllowOrigins: "*",
@ -128,7 +129,7 @@ func New(config ...Config) fiber.Handler {
}
// Simple request
if c.Method() != http.MethodOptions {
if c.Method() != fiber.MethodOptions {
c.Vary(fiber.HeaderOrigin)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)

View File

@ -6,6 +6,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -237,7 +238,7 @@ func Test_CORS_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

View File

@ -1,6 +1,8 @@
package cors
import "strings"
import (
"strings"
)
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
@ -20,18 +22,20 @@ func matchSubdomain(domain, pattern string) bool {
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
const maxDomainLen = 253
if len(domAuth) > maxDomainLen {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
const divHalf = 2
for i := len(domComp)/divHalf - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
for i := len(patComp)/divHalf - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}

View File

@ -1,7 +1,7 @@
package csrf
import (
"fmt"
"log"
"net/textproto"
"strings"
"time"
@ -80,13 +80,13 @@ type Config struct {
// Optional. Default: utils.UUID
KeyGenerator func() string
// Deprecated, please use Expiration
// Deprecated: Please use Expiration
CookieExpires time.Duration
// Deprecated, please use Cookie* related fields
// Deprecated: Please use Cookie* related fields
Cookie *fiber.Cookie
// Deprecated, please use KeyLookup
// Deprecated: Please use KeyLookup
TokenLookup string
// ErrorHandler is executed when an error is returned from fiber.Handler.
@ -105,6 +105,8 @@ type Config struct {
const HeaderName = "X-Csrf-Token"
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
@ -116,7 +118,7 @@ var ConfigDefault = Config{
}
// default ErrorHandler that process return error from fiber.Handler
var defaultErrorHandler = func(c *fiber.Ctx, err error) error {
func defaultErrorHandler(_ *fiber.Ctx, _ error) error {
return fiber.ErrForbidden
}
@ -132,15 +134,15 @@ func configDefault(config ...Config) Config {
// Set default values
if cfg.TokenLookup != "" {
fmt.Println("[CSRF] TokenLookup is deprecated, please use KeyLookup")
log.Printf("[CSRF] TokenLookup is deprecated, please use KeyLookup\n")
cfg.KeyLookup = cfg.TokenLookup
}
if int(cfg.CookieExpires.Seconds()) > 0 {
fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration")
log.Printf("[CSRF] CookieExpires is deprecated, please use Expiration\n")
cfg.Expiration = cfg.CookieExpires
}
if cfg.Cookie != nil {
fmt.Println("[CSRF] Cookie is deprecated, please use Cookie* related fields")
log.Printf("[CSRF] Cookie is deprecated, please use Cookie* related fields\n")
if cfg.Cookie.Name != "" {
cfg.CookieName = cfg.Cookie.Name
}
@ -178,7 +180,8 @@ func configDefault(config ...Config) Config {
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")
if len(selectors) != 2 {
const numParts = 2
if len(selectors) != numParts {
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
}

View File

@ -7,9 +7,7 @@ import (
"github.com/gofiber/fiber/v2"
)
var (
errTokenNotFound = errors.New("csrf token not found")
)
var errTokenNotFound = errors.New("csrf token not found")
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
@ -22,7 +20,7 @@ func New(config ...Config) fiber.Handler {
dummyValue := []byte{'+'}
// Return new handler
return func(c *fiber.Ctx) (err error) {
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
@ -39,7 +37,7 @@ func New(config ...Config) fiber.Handler {
// Assume that anything not defined as 'safe' by RFC7231 needs protection
// Extract token from client request i.e. header, query, param, form or cookie
token, err = cfg.Extractor(c)
token, err := cfg.Extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}

View File

@ -7,6 +7,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -23,7 +24,7 @@ func Test_CSRF(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
methods := [4]string{"GET", "HEAD", "OPTIONS", "TRACE"}
methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
for _, method := range methods {
// Generate CSRF token
@ -33,14 +34,14 @@ func Test_CSRF(t *testing.T) {
// Without CSRF cookie
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Empty/invalid CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, "johndoe")
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -55,7 +56,7 @@ func Test_CSRF(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -72,7 +73,7 @@ func Test_CSRF_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -92,7 +93,7 @@ func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
}
@ -110,7 +111,7 @@ func Test_CSRF_From_Form(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -118,12 +119,12 @@ func Test_CSRF_From_Form(t *testing.T) {
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
ctx.Request.SetBodyString("_csrf=" + token)
h(ctx)
@ -144,7 +145,7 @@ func Test_CSRF_From_Query(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -152,7 +153,7 @@ func Test_CSRF_From_Query(t *testing.T) {
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
@ -161,7 +162,7 @@ func Test_CSRF_From_Query(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.SetRequestURI("/?_csrf=" + token)
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
@ -181,7 +182,7 @@ func Test_CSRF_From_Param(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/" + utils.UUID())
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -189,7 +190,7 @@ func Test_CSRF_From_Param(t *testing.T) {
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/" + utils.UUID())
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
@ -198,7 +199,7 @@ func Test_CSRF_From_Param(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.SetRequestURI("/" + token)
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
@ -218,7 +219,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
h(ctx)
@ -227,7 +228,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
@ -235,7 +236,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
ctx.Request.SetRequestURI("/")
h(ctx)
@ -268,7 +269,7 @@ func Test_CSRF_From_Custom(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
@ -276,12 +277,12 @@ func Test_CSRF_From_Custom(t *testing.T) {
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
ctx.Request.SetBodyString("_csrf=" + token)
h(ctx)
@ -307,13 +308,13 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
// invalid CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, "johndoe")
h(ctx)
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
@ -339,69 +340,69 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
// empty CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
}
// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
//func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
// t.Parallel()
// app := fiber.New()
//
// app.Use(New())
// app.Get("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Get("/test", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Post("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
//
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
//
// var token string
// for _, c := range resp.Cookies() {
// if c.Name != ConfigDefault.CookieName {
// continue
// }
// token = c.Value
// break
// }
//
// fmt.Println("token", token)
//
// getReq := httptest.NewRequest(http.MethodGet, "/", nil)
// getReq.Header.Set(HeaderName, token)
// resp, err = app.Test(getReq)
//
// getReq = httptest.NewRequest(http.MethodGet, "/test", nil)
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
// getReq.Header.Set(HeaderName, token)
//
// resp, err = app.Test(getReq)
//
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
// getReq.Header.Del(HeaderName)
// resp, err = app.Test(getReq)
//
// postReq := httptest.NewRequest(http.MethodPost, "/", nil)
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// postReq.Header.Set(HeaderName, token)
// resp, err = app.Test(postReq)
//}
// app := fiber.New()
// app.Use(New())
// app.Get("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Get("/test", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Post("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// var token string
// for _, c := range resp.Cookies() {
// if c.Name != ConfigDefault.CookieName {
// continue
// }
// token = c.Value
// break
// }
// fmt.Println("token", token)
// getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
// getReq.Header.Set(HeaderName, token)
// resp, err = app.Test(getReq)
// getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
// getReq.Header.Set(HeaderName, token)
// resp, err = app.Test(getReq)
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
// getReq.Header.Del(HeaderName)
// resp, err = app.Test(getReq)
// postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// postReq.Header.Set(HeaderName, token)
// resp, err = app.Test(postReq)
// }
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
@ -417,12 +418,12 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
ctx := &fasthttp.RequestCtx{}
// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
b.ReportAllocs()
@ -449,7 +450,7 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
ctx := &fasthttp.RequestCtx{}
// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
b.ReportAllocs()
b.ResetTimer()

View File

@ -41,74 +41,23 @@ func newManager(storage fiber.Storage) *manager {
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
// don't release item if we using memory storage
if m.storage != nil {
return
}
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
func (m *manager) getRaw(key string) []byte {
var raw []byte
if m.storage != nil {
raw, _ = m.storage.Get(key)
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
} else {
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), it, exp)
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error
}
return raw
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error
} else {
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -1,6 +1,8 @@
package encryptcookie
import "github.com/gofiber/fiber/v2"
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
@ -32,6 +34,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
Except: []string{"csrf_"},

View File

@ -2,6 +2,7 @@ package encryptcookie
import (
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)

View File

@ -7,9 +7,11 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
//nolint:gochecknoglobals // Using a global var is fine here
var testKey = GenerateKey()
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
@ -35,14 +37,14 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
// Test empty cookie
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
// Test invalid cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", "Invalid")
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -54,18 +56,19 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
// Test valid cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -91,7 +94,7 @@ func Test_Encrypt_Cookie_Next(t *testing.T) {
return nil
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
}
@ -123,7 +126,7 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -135,7 +138,8 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test2")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
}
@ -169,18 +173,19 @@ func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decodedBytes, _ := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
decodedBytes, err := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", string(decodedBytes))
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())

View File

@ -6,47 +6,56 @@ import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
)
// EncryptCookie Encrypts a cookie value with specific encryption key
func EncryptCookie(value, key string) (string, error) {
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
plaintext := []byte(value)
keyDecoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return "", fmt.Errorf("failed to base64-decode key: %w", err)
}
block, err := aes.NewCipher(keyDecoded)
if err != nil {
return "", err
return "", fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
return "", fmt.Errorf("failed to create GCM mode: %w", err)
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
return "", fmt.Errorf("failed to read: %w", err)
}
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptCookie Decrypts a cookie value with specific encryption key
func DecryptCookie(value, key string) (string, error) {
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
enc, _ := base64.StdEncoding.DecodeString(value)
keyDecoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return "", fmt.Errorf("failed to base64-decode key: %w", err)
}
enc, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return "", fmt.Errorf("failed to base64-decode value: %w", err)
}
block, err := aes.NewCipher(keyDecoded)
if err != nil {
return "", err
return "", fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
return "", fmt.Errorf("failed to create GCM mode: %w", err)
}
nonceSize := gcm.NonceSize()
@ -59,7 +68,7 @@ func DecryptCookie(value, key string) (string, error) {
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", err
return "", fmt.Errorf("failed to decrypt ciphertext: %w", err)
}
return string(plaintext), nil
@ -67,7 +76,8 @@ func DecryptCookie(value, key string) (string, error) {
// GenerateKey Generates an encryption key
func GenerateKey() string {
ret := make([]byte, 32)
const keyLen = 32
ret := make([]byte, keyLen)
if _, err := rand.Read(ret); err != nil {
panic(err)

View File

@ -23,10 +23,8 @@ func (envVar *EnvVar) set(key, val string) {
envVar.Vars[key] = val
}
var defaultConfig = Config{}
func New(config ...Config) fiber.Handler {
var cfg = defaultConfig
var cfg Config
if len(config) > 0 {
cfg = config[0]
}
@ -57,8 +55,9 @@ func newEnvVar(cfg Config) *EnvVar {
}
}
} else {
const numElems = 2
for _, envVal := range os.Environ() {
keyVal := strings.SplitN(envVal, "=", 2)
keyVal := strings.SplitN(envVal, "=", numElems)
if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists {
vars.set(keyVal[0], keyVal[1])
}

View File

@ -1,6 +1,8 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package envvar
import (
"context"
"encoding/json"
"io"
"net/http"
@ -12,16 +14,25 @@ import (
)
func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
os.Setenv("testKey", "testEnvValue")
os.Setenv("anotherEnvKey", "anotherEnvVal")
os.Setenv("excludeKey", "excludeEnvValue")
defer os.Unsetenv("testKey")
defer os.Unsetenv("anotherEnvKey")
defer os.Unsetenv("excludeKey")
err := os.Setenv("testKey", "testEnvValue")
utils.AssertEqual(t, nil, err)
err = os.Setenv("anotherEnvKey", "anotherEnvVal")
utils.AssertEqual(t, nil, err)
err = os.Setenv("excludeKey", "excludeEnvValue")
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testKey")
utils.AssertEqual(t, nil, err)
err = os.Unsetenv("anotherEnvKey")
utils.AssertEqual(t, nil, err)
err = os.Unsetenv("excludeKey")
utils.AssertEqual(t, nil, err)
}()
vars := newEnvVar(Config{
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
ExcludeVars: map[string]string{"excludeKey": ""}})
ExcludeVars: map[string]string{"excludeKey": ""},
})
utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue")
utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal")
@ -30,21 +41,28 @@ func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
}
func TestEnvVarHandler(t *testing.T) {
os.Setenv("testKey", "testVal")
defer os.Unsetenv("testKey")
err := os.Setenv("testKey", "testVal")
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testKey")
utils.AssertEqual(t, nil, err)
}()
expectedEnvVarResponse, _ := json.Marshal(
expectedEnvVarResponse, err := json.Marshal(
struct {
Vars map[string]string `json:"vars"`
}{
map[string]string{"testKey": "testVal"},
})
utils.AssertEqual(t, nil, err)
app := fiber.New()
app.Use("/envvars", New(Config{
ExportVars: map[string]string{"testKey": ""}}))
ExportVars: map[string]string{"testKey": ""},
}))
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -57,14 +75,16 @@ func TestEnvVarHandler(t *testing.T) {
func TestEnvVarHandlerNotMatched(t *testing.T) {
app := fiber.New()
app.Use("/envvars", New(Config{
ExportVars: map[string]string{"testKey": ""}}))
ExportVars: map[string]string{"testKey": ""},
}))
app.Get("/another-path", func(ctx *fiber.Ctx) error {
utils.AssertEqual(t, nil, ctx.SendString("OK"))
return nil
})
req, _ := http.NewRequest("GET", "http://localhost/another-path", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -75,13 +95,18 @@ func TestEnvVarHandlerNotMatched(t *testing.T) {
}
func TestEnvVarHandlerDefaultConfig(t *testing.T) {
os.Setenv("testEnvKey", "testEnvVal")
defer os.Unsetenv("testEnvKey")
err := os.Setenv("testEnvKey", "testEnvVal")
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testEnvKey")
utils.AssertEqual(t, nil, err)
}()
app := fiber.New()
app.Use("/envvars", New())
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -98,7 +123,8 @@ func TestEnvVarHandlerMethod(t *testing.T) {
app := fiber.New()
app.Use("/envvars", New())
req, _ := http.NewRequest("POST", "http://localhost/envvars", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
@ -107,14 +133,19 @@ func TestEnvVarHandlerMethod(t *testing.T) {
func TestEnvVarHandlerSpecialValue(t *testing.T) {
testEnvKey := "testEnvKey"
fakeBase64 := "testBase64:TQ=="
os.Setenv(testEnvKey, fakeBase64)
defer os.Unsetenv(testEnvKey)
err := os.Setenv(testEnvKey, fakeBase64)
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv(testEnvKey)
utils.AssertEqual(t, nil, err)
}()
app := fiber.New()
app.Use("/envvars", New())
app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -126,7 +157,8 @@ func TestEnvVarHandlerSpecialValue(t *testing.T) {
val := envVars.Vars[testEnvKey]
utils.AssertEqual(t, fakeBase64, val)
req, _ = http.NewRequest("GET", "http://localhost/envvars/export", nil)
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", nil)
utils.AssertEqual(t, nil, err)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)

View File

@ -23,6 +23,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Weak: false,
Next: nil,

View File

@ -5,12 +5,8 @@ import (
"hash/crc32"
"github.com/gofiber/fiber/v2"
"github.com/valyala/bytebufferpool"
)
var (
normalizedHeaderETag = []byte("Etag")
weakPrefix = []byte("W/")
"github.com/valyala/bytebufferpool"
)
// New creates a new middleware handler
@ -18,32 +14,38 @@ func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
crc32q := crc32.MakeTable(0xD5828281)
var (
normalizedHeaderETag = []byte("Etag")
weakPrefix = []byte("W/")
)
const crcPol = 0xD5828281
crc32q := crc32.MakeTable(crcPol)
// Return new handler
return func(c *fiber.Ctx) (err error) {
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Return err if next handler returns one
if err = c.Next(); err != nil {
return
if err := c.Next(); err != nil {
return err
}
// Don't generate ETags for invalid responses
if c.Response().StatusCode() != fiber.StatusOK {
return
return nil
}
body := c.Response().Body()
// Skips ETag if no response body is present
if len(body) == 0 {
return
return nil
}
// Skip ETag if header is already present
if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
return
return nil
}
// Generate ETag for response
@ -52,14 +54,14 @@ func New(config ...Config) fiber.Handler {
// Enable weak tag
if cfg.Weak {
_, _ = bb.Write(weakPrefix)
_, _ = bb.Write(weakPrefix) //nolint:errcheck // This will never fail
}
_ = bb.WriteByte('"')
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
bb.B = appendUint(bb.Bytes(), uint32(len(body)))
_ = bb.WriteByte('-')
_ = bb.WriteByte('-') //nolint:errcheck // This will never fail
bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q))
_ = bb.WriteByte('"')
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
etag := bb.Bytes()
@ -78,7 +80,7 @@ func New(config ...Config) fiber.Handler {
// W/1 != W/2 || W/1 != 2
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
return
return nil
}
if bytes.Contains(clientEtag, etag) {
@ -90,7 +92,7 @@ func New(config ...Config) fiber.Handler {
// 1 != 2
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
return
return nil
}
}
@ -102,7 +104,7 @@ func appendUint(dst []byte, n uint32) []byte {
var q uint32
for n >= 10 {
i--
q = n / 10
q = n / 10 //nolint:gomnd // TODO: Explain why we divide by 10 here
buf[i] = '0' + byte(n-q*10)
n = q
}

View File

@ -8,6 +8,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -21,7 +22,7 @@ func Test_ETag_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -37,7 +38,7 @@ func Test_ETag_SkipError(t *testing.T) {
return fiber.ErrForbidden
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
}
@ -53,7 +54,7 @@ func Test_ETag_NotStatusOK(t *testing.T) {
return c.SendStatus(fiber.StatusCreated)
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode)
}
@ -69,7 +70,7 @@ func Test_ETag_NoBody(t *testing.T) {
return nil
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
@ -91,7 +92,7 @@ func Test_ETag_NewEtag(t *testing.T) {
})
}
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) {
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper()
app := fiber.New()
@ -102,7 +103,7 @@ func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) {
return c.SendString("Hello, World!")
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch {
etag := `"non-match"`
if matched {
@ -145,7 +146,7 @@ func Test_ETag_WeakEtag(t *testing.T) {
})
}
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) {
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper()
app := fiber.New()
@ -156,7 +157,7 @@ func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) {
return c.SendString("Hello, World!")
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch {
etag := `W/"non-match"`
if matched {
@ -199,7 +200,7 @@ func Test_ETag_CustomEtag(t *testing.T) {
})
}
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) {
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper()
app := fiber.New()
@ -214,7 +215,7 @@ func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) {
return c.SendString("Hello, World!")
})
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch {
etag := `"non-match"`
if matched {
@ -255,7 +256,7 @@ func Test_ETag_CustomEtagPut(t *testing.T) {
return c.SendString("Hello, World!")
})
req := httptest.NewRequest("PUT", "/", nil)
req := httptest.NewRequest(fiber.MethodPut, "/", nil)
req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
@ -275,7 +276,7 @@ func Benchmark_Etag(b *testing.B) {
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ReportAllocs()

View File

@ -1,6 +1,8 @@
package expvar
import "github.com/gofiber/fiber/v2"
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
@ -10,6 +12,7 @@ type Config struct {
Next func(c *fiber.Ctx) bool
}
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
}

View File

@ -4,6 +4,7 @@ import (
"strings"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp/expvarhandler"
)
@ -29,6 +30,6 @@ func New(config ...Config) fiber.Handler {
return nil
}
return c.Redirect("/debug/vars", 302)
return c.Redirect("/debug/vars", fiber.StatusFound)
}
}

View File

@ -34,6 +34,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
File: "",

View File

@ -1,16 +1,18 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package favicon
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_Middleware_Favicon
@ -25,22 +27,22 @@ func Test_Middleware_Favicon(t *testing.T) {
})
// Skip Favicon middleware
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest("OPTIONS", "/favicon.ico", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodOptions, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest("PUT", "/favicon.ico", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodPut, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
utils.AssertEqual(t, "GET, HEAD, OPTIONS", resp.Header.Get(fiber.HeaderAllow))
utils.AssertEqual(t, strings.Join([]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions}, ", "), resp.Header.Get(fiber.HeaderAllow))
}
// go test -run Test_Middleware_Favicon_Not_Found
@ -70,8 +72,7 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
return nil
})
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
@ -83,15 +84,15 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
// TODO use os.Dir if fiber upgrades to 1.16
type mockFS struct{}
func (m mockFS) Open(name string) (http.File, error) {
func (mockFS) Open(name string) (http.File, error) {
if name == "/" {
name = "."
} else {
name = strings.TrimPrefix(name, "/")
}
file, err := os.Open(name)
file, err := os.Open(name) //nolint:gosec // We're in a test func, so this is fine
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to open: %w", err)
}
return file, nil
}
@ -106,7 +107,7 @@ func Test_Middleware_Favicon_FileSystem(t *testing.T) {
FileSystem: mockFS{},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
@ -123,7 +124,7 @@ func Test_Middleware_Favicon_CacheControl(t *testing.T) {
File: "../../.github/testdata/favicon.ico",
}))
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
@ -159,7 +160,7 @@ func Test_Favicon_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

View File

@ -1,6 +1,7 @@
package filesystem
import (
"fmt"
"net/http"
"os"
"strconv"
@ -55,6 +56,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
Root: nil,
@ -102,7 +105,7 @@ func New(config ...Config) fiber.Handler {
cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge)
// Return new handler
return func(c *fiber.Ctx) (err error) {
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
@ -131,28 +134,23 @@ func New(config ...Config) fiber.Handler {
path = cfg.PathPrefix + path
}
var (
file http.File
stat os.FileInfo
)
if len(path) > 1 {
path = utils.TrimRight(path, '/')
}
file, err = cfg.Root.Open(path)
file, err := cfg.Root.Open(path)
if err != nil && os.IsNotExist(err) && cfg.NotFoundFile != "" {
file, err = cfg.Root.Open(cfg.NotFoundFile)
}
if err != nil {
if os.IsNotExist(err) {
return c.Status(fiber.StatusNotFound).Next()
}
return
return fmt.Errorf("failed to open: %w", err)
}
if stat, err = file.Stat(); err != nil {
return
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat: %w", err)
}
// Serve index if path is directory
@ -200,7 +198,7 @@ func New(config ...Config) fiber.Handler {
c.Response().SkipBody = true
c.Response().Header.SetContentLength(contentLength)
if err := file.Close(); err != nil {
return err
return fmt.Errorf("failed to close: %w", err)
}
return nil
}
@ -210,22 +208,18 @@ func New(config ...Config) fiber.Handler {
}
// SendFile ...
func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) {
var (
file http.File
stat os.FileInfo
)
file, err = fs.Open(path)
func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) error {
file, err := fs.Open(path)
if err != nil {
if os.IsNotExist(err) {
return fiber.ErrNotFound
}
return err
return fmt.Errorf("failed to open: %w", err)
}
if stat, err = file.Stat(); err != nil {
return err
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat: %w", err)
}
// Serve index if path is directory
@ -268,7 +262,7 @@ func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) {
c.Response().SkipBody = true
c.Response().Header.SetContentLength(contentLength)
if err := file.Close(); err != nil {
return err
return fmt.Errorf("failed to close: %w", err)
}
return nil
}

View File

@ -1,6 +1,8 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package filesystem
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@ -119,7 +121,7 @@ func Test_FileSystem(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
resp, err := app.Test(httptest.NewRequest("GET", tt.url, nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tt.url, nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
@ -142,7 +144,7 @@ func Test_FileSystem_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -168,7 +170,8 @@ func Test_FileSystem_Head(t *testing.T) {
Root: http.Dir("../../.github/testdata/fs"),
}))
req, _ := http.NewRequest(fiber.MethodHead, "/test", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/test", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
@ -182,7 +185,8 @@ func Test_FileSystem_NoRoot(t *testing.T) {
app := fiber.New()
app.Use(New())
_, _ = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
}
func Test_FileSystem_UsingParam(t *testing.T) {
@ -193,7 +197,8 @@ func Test_FileSystem_UsingParam(t *testing.T) {
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
})
req, _ := http.NewRequest(fiber.MethodHead, "/index", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/index", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
@ -207,7 +212,8 @@ func Test_FileSystem_UsingParam_NonFile(t *testing.T) {
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
})
req, _ := http.NewRequest(fiber.MethodHead, "/template", nil)
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/template", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode)

View File

@ -13,18 +13,18 @@ import (
"github.com/gofiber/fiber/v2/utils"
)
func getFileExtension(path string) string {
n := strings.LastIndexByte(path, '.')
func getFileExtension(p string) string {
n := strings.LastIndexByte(p, '.')
if n < 0 {
return ""
}
return path[n:]
return p[n:]
}
func dirList(c *fiber.Ctx, f http.File) error {
fileinfos, err := f.Readdir(-1)
if err != nil {
return err
return fmt.Errorf("failed to read dir: %w", err)
}
fm := make(map[string]os.FileInfo, len(fileinfos))
@ -36,13 +36,13 @@ func dirList(c *fiber.Ctx, f http.File) error {
}
basePathEscaped := html.EscapeString(c.Path())
fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
fmt.Fprint(c, "<ul>")
_, _ = fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
_, _ = fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
_, _ = fmt.Fprint(c, "<ul>")
if len(basePathEscaped) > 1 {
parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..")
fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
}
sort.Strings(filenames)
@ -55,10 +55,10 @@ func dirList(c *fiber.Ctx, f http.File) error {
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
className = "file"
}
fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime())
}
fmt.Fprint(c, "</ul></body></html>")
_, _ = fmt.Fprint(c, "</ul></body></html>")
c.Type("html")

View File

@ -9,9 +9,7 @@ import (
"github.com/gofiber/fiber/v2/internal/storage/memory"
)
var (
ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
)
var ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
// Config defines the config for middleware.
type Config struct {
@ -51,13 +49,15 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: func(c *fiber.Ctx) bool {
// Skip middleware if the request was done using a safe HTTP method
return fiber.IsMethodSafe(c.Method())
},
Lifetime: 30 * time.Minute,
Lifetime: 30 * time.Minute, //nolint:gomnd // No magic number, just the default config
KeyHeader: "X-Idempotency-Key",
KeyHeaderValidate: func(k string) error {
@ -112,7 +112,7 @@ func configDefault(config ...Config) Config {
if cfg.Storage == nil {
cfg.Storage = memory.New(memory.Config{
GCInterval: cfg.Lifetime / 2,
GCInterval: cfg.Lifetime / 2, //nolint:gomnd // Half the lifetime interval
})
}

View File

@ -1,3 +1,4 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package idempotency_test
import (
@ -14,6 +15,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/idempotency"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -172,5 +174,4 @@ func Benchmark_Idempotency(b *testing.B) {
h(c)
}
})
}

View File

@ -1,7 +1,7 @@
package limiter
import (
"fmt"
"log"
"time"
"github.com/gofiber/fiber/v2"
@ -58,19 +58,21 @@ type Config struct {
// Default: a new Fixed Window Rate Limiter
LimiterMiddleware LimiterHandler
// DEPRECATED: Use Expiration instead
// Deprecated: Use Expiration instead
Duration time.Duration
// DEPRECATED, use Storage instead
// Deprecated: Use Storage instead
Store fiber.Storage
// DEPRECATED, use KeyGenerator instead
// Deprecated: Use KeyGenerator instead
Key func(*fiber.Ctx) string
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Max: 5,
Max: 5, //nolint:gomnd // No magic number, just the default config
Expiration: 1 * time.Minute,
KeyGenerator: func(c *fiber.Ctx) string {
return c.IP()
@ -95,15 +97,15 @@ func configDefault(config ...Config) Config {
// Set default values
if int(cfg.Duration.Seconds()) > 0 {
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration")
log.Printf("[LIMITER] Duration is deprecated, please use Expiration\n")
cfg.Expiration = cfg.Duration
}
if cfg.Key != nil {
fmt.Println("[LIMITER] Key is deprecated, please us KeyGenerator")
log.Printf("[LIMITER] Key is deprecated, please us KeyGenerator\n")
cfg.KeyGenerator = cfg.Key
}
if cfg.Store != nil {
fmt.Println("[LIMITER] Store is deprecated, please use Storage")
log.Printf("[LIMITER] Store is deprecated, please use Storage\n")
cfg.Storage = cfg.Store
}
if cfg.Next == nil {

View File

@ -2,7 +2,6 @@ package limiter
import (
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
@ -11,6 +10,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -34,7 +34,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -50,13 +50,13 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
wg.Wait()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
@ -80,7 +80,7 @@ func Test_Limiter_Concurrency(t *testing.T) {
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -96,13 +96,13 @@ func Test_Limiter_Concurrency(t *testing.T) {
wg.Wait()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
@ -120,21 +120,21 @@ func Test_Limiter_No_Skip_Choices(t *testing.T) {
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
if c.Params("status") == "fail" { //nolint:goconst // False positive
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
}
@ -157,21 +157,21 @@ func Test_Limiter_Skip_Failed_Requests(t *testing.T) {
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
@ -196,21 +196,21 @@ func Test_Limiter_Skip_Successful_Requests(t *testing.T) {
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
@ -232,7 +232,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) {
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ResetTimer()
@ -252,7 +252,7 @@ func Test_Limiter_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -271,7 +271,7 @@ func Test_Limiter_Headers(t *testing.T) {
})
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
app.Handler()(fctx)
@ -301,7 +301,7 @@ func Benchmark_Limiter(b *testing.B) {
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ResetTimer()
@ -327,7 +327,7 @@ func Test_Sliding_Window(t *testing.T) {
})
singleRequest := func(shouldFail bool) {
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
if shouldFail {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)

View File

@ -46,7 +46,7 @@ func newManager(storage fiber.Storage) *manager {
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
}
// release and reset *entry to sync.Pool
@ -58,37 +58,33 @@ func (m *manager) release(e *item) {
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
func (m *manager) get(key string) *item {
var it *item
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
raw, err := m.storage.Get(key)
if err != nil {
return it
}
if raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
return it
}
}
return
return it
}
if it, _ = m.memory.Get(key).(*item); it == nil {
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire()
return it
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
return it
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
}
// we can release data because it's serialized to database
m.release(it)
@ -96,21 +92,3 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -95,7 +95,7 @@ app.Use(logger.New(logger.Config{
TimeZone: "Asia/Shanghai",
Done: func(c *fiber.Ctx, logString []byte) {
if c.Response().StatusCode() != fiber.StatusOK {
reporter.SendToSlack(logString)
reporter.SendToSlack(logString)
}
},
}))
@ -189,7 +189,7 @@ const (
TagBytesReceived = "bytesReceived"
TagRoute = "route"
TagError = "error"
// DEPRECATED: Use TagReqHeader instead
// Deprecated: Use TagReqHeader instead
TagHeader = "header:" // request header
TagReqHeader = "reqHeader:" // request header
TagRespHeader = "respHeader:" // response header

View File

@ -79,13 +79,15 @@ type Buffer interface {
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
Done: nil,
Format: "[${time}] ${status} - ${latency} ${method} ${path}\n",
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
TimeInterval: 500 * time.Millisecond, //nolint:gomnd // No magic number, just the default config
Output: os.Stdout,
enableColors: true,
}

View File

@ -1,13 +1,10 @@
package logger
import (
"sync"
"sync/atomic"
"time"
)
var DataPool = sync.Pool{New: func() interface{} { return new(Data) }}
// Data is a struct to define some variables to use in custom logger function.
type Data struct {
Pid string

View File

@ -11,6 +11,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
"github.com/valyala/bytebufferpool"
@ -55,6 +56,8 @@ func New(config ...Config) fiber.Handler {
once sync.Once
mu sync.Mutex
errHandler fiber.ErrorHandler
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
)
// If colors are enabled, check terminal compatibility
@ -75,7 +78,7 @@ func New(config ...Config) fiber.Handler {
}
// Return new handler
return func(c *fiber.Ctx) (err error) {
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
@ -101,13 +104,13 @@ func New(config ...Config) fiber.Handler {
})
// Logger data
data := DataPool.Get().(*Data)
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
// no need for a reset, as long as we always override everything
data.Pid = pid
data.ErrPaddingStr = errPaddingStr
data.Timestamp = timestamp
// put data back in the pool
defer DataPool.Put(data)
defer dataPool.Put(data)
// Set latency start time
if cfg.enableLatency {
@ -121,7 +124,7 @@ func New(config ...Config) fiber.Handler {
// Manually call error handler
if chainErr != nil {
if err := errHandler(c, chainErr); err != nil {
_ = c.SendStatus(fiber.StatusInternalServerError)
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
}
}
@ -142,18 +145,20 @@ func New(config ...Config) fiber.Handler {
}
// Format log to buffer
_, _ = buf.WriteString(fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n",
timestamp.Load().(string),
statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset,
data.Stop.Sub(data.Start).Round(time.Millisecond),
c.IP(),
methodColor(c.Method(), colors), c.Method(), colors.Reset,
c.Path(),
formatErr,
))
_, _ = buf.WriteString( //nolint:errcheck // This will never fail
fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n",
timestamp.Load().(string),
statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset,
data.Stop.Sub(data.Start).Round(time.Millisecond),
c.IP(),
methodColor(c.Method(), colors), c.Method(), colors.Reset,
c.Path(),
formatErr,
),
)
// Write buffer to output
_, _ = cfg.Output.Write(buf.Bytes())
_, _ = cfg.Output.Write(buf.Bytes()) //nolint:errcheck // This will never fail
if cfg.Done != nil {
cfg.Done(c, buf.Bytes())
@ -169,7 +174,7 @@ func New(config ...Config) fiber.Handler {
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
for i, logFunc := range logFunChain {
if logFunc == nil {
_, _ = buf.Write(templateChain[i])
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
} else if templateChain[i] == nil {
_, err = logFunc(buf, c, data, "")
} else {
@ -182,7 +187,7 @@ func New(config ...Config) fiber.Handler {
// Also write errors to the buffer
if err != nil {
_, _ = buf.WriteString(err.Error())
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
}
mu.Lock()
// Write buffer to output
@ -190,7 +195,7 @@ func New(config ...Config) fiber.Handler {
// Write error to output
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
// There is something wrong with the given io.Writer
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
}
}
mu.Unlock()

View File

@ -1,3 +1,4 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package logger
import (
@ -16,6 +17,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
@ -37,7 +39,7 @@ func Test_Logger(t *testing.T) {
return errors.New("some random error")
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
utils.AssertEqual(t, "some random error", buf.String())
@ -70,21 +72,21 @@ func Test_Logger_locals(t *testing.T) {
return c.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "johndoe", buf.String())
buf.Reset()
resp, err = app.Test(httptest.NewRequest("GET", "/int", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/int", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "55", buf.String())
buf.Reset()
resp, err = app.Test(httptest.NewRequest("GET", "/empty", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "", buf.String())
@ -100,7 +102,7 @@ func Test_Logger_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -113,15 +115,15 @@ func Test_Logger_Done(t *testing.T) {
app.Use(New(Config{
Done: func(c *fiber.Ctx, logString []byte) {
if c.Response().StatusCode() == fiber.StatusOK {
buf.Write(logString)
_, err := buf.Write(logString)
utils.AssertEqual(t, nil, err)
}
},
})).Get("/logging", func(ctx *fiber.Ctx) error {
return ctx.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest("GET", "/logging", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, true, buf.Len() > 0)
@ -135,7 +137,7 @@ func Test_Logger_ErrorTimeZone(t *testing.T) {
TimeZone: "invalid",
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -156,7 +158,7 @@ func Test_Logger_ErrorOutput(t *testing.T) {
Output: o,
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -178,7 +180,7 @@ func Test_Logger_All(t *testing.T) {
// Alias colors
colors := app.Config().ColorScheme
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -198,7 +200,7 @@ func Test_Query_Params(t *testing.T) {
Output: buf,
}))
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar&baz=moz", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -226,7 +228,7 @@ func Test_Response_Body(t *testing.T) {
return c.Send([]byte("Post in test"))
})
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
expectedGetResponse := "Sample response body"
@ -234,7 +236,7 @@ func Test_Response_Body(t *testing.T) {
buf.Reset() // Reset buffer to test POST
_, err = app.Test(httptest.NewRequest("POST", "/test", nil))
_, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
utils.AssertEqual(t, nil, err)
expectedPostResponse := "Post in test"
@ -258,7 +260,7 @@ func Test_Logger_AppendUint(t *testing.T) {
return c.SendString("hello")
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "0 5 200", buf.String())
@ -285,12 +287,11 @@ func Test_Logger_Data_Race(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
resp1, err1 = app.Test(httptest.NewRequest("GET", "/", nil))
resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
wg.Done()
}()
resp2, err2 = app.Test(httptest.NewRequest("GET", "/", nil))
resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
wg.Wait()
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode)
utils.AssertEqual(t, nil, err2)
@ -299,21 +300,23 @@ func Test_Logger_Data_Race(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4
func Benchmark_Logger(b *testing.B) {
benchSetup := func(bb *testing.B, app *fiber.App) {
benchSetup := func(b *testing.B, app *fiber.App) {
b.Helper()
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
bb.ReportAllocs()
bb.ResetTimer()
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < bb.N; n++ {
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(bb, 200, fctx.Response.Header.StatusCode())
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
}
b.Run("Base", func(bb *testing.B) {
@ -375,8 +378,7 @@ func Test_Response_Header(t *testing.T) {
return c.SendString("Hello fiber!")
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String())
@ -396,10 +398,10 @@ func Test_Req_Header(t *testing.T) {
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!")
})
headerReq := httptest.NewRequest("GET", "/", nil)
headerReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
headerReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(headerReq)
resp, err := app.Test(headerReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String())
@ -419,10 +421,10 @@ func Test_ReqHeader_Header(t *testing.T) {
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!")
})
reqHeaderReq := httptest.NewRequest("GET", "/", nil)
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
reqHeaderReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(reqHeaderReq)
resp, err := app.Test(reqHeaderReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String())
@ -449,10 +451,10 @@ func Test_CustomTags(t *testing.T) {
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!")
})
reqHeaderReq := httptest.NewRequest("GET", "/", nil)
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
reqHeaderReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(reqHeaderReq)
resp, err := app.Test(reqHeaderReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, customTag, buf.String())
@ -492,7 +494,7 @@ func Test_Logger_ByteSent_Streaming(t *testing.T) {
return nil
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "0 0 200", buf.String())

View File

@ -31,7 +31,7 @@ const (
TagBytesReceived = "bytesReceived"
TagRoute = "route"
TagError = "error"
// DEPRECATED: Use TagReqHeader instead
// Deprecated: Use TagReqHeader instead
TagHeader = "header:"
TagReqHeader = "reqHeader:"
TagRespHeader = "respHeader:"
@ -195,7 +195,7 @@ func createTagMap(cfg *Config) map[string]LogFunc {
return output.WriteString(fmt.Sprintf("%7v", latency))
},
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(data.Timestamp.Load().(string))
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
},
}
// merge with custom tags from user

View File

@ -14,13 +14,16 @@ import (
// funcChain contains for the parts which exist the functions for the dynamic parts
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [][]byte, funcChain []LogFunc, err error) {
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
templateB := utils.UnsafeBytes(cfg.Format)
startTagB := utils.UnsafeBytes(startTag)
endTagB := utils.UnsafeBytes(endTag)
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
var fixParts [][]byte
var funcChain []LogFunc
for {
currentPos := bytes.Index(templateB, startTagB)
if currentPos < 0 {
@ -42,13 +45,13 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
// ## function block ##
// first check for tags with parameters
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]; ok {
funcChain = append(funcChain, logFunc)
// add param to the fixParts
fixParts = append(fixParts, templateB[index+1:currentPos])
} else {
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
if !ok {
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
}
funcChain = append(funcChain, logFunc)
// add param to the fixParts
fixParts = append(fixParts, templateB[index+1:currentPos])
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
// add functions without parameter
funcChain = append(funcChain, logFunc)
@ -63,5 +66,5 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
funcChain = append(funcChain, nil)
fixParts = append(fixParts, templateB)
return
return fixParts, funcChain, nil
}

View File

@ -41,36 +41,48 @@ type Config struct {
// ChartJsURL for specify ChartJS library path or URL . also you can use relative path
//
// Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js
ChartJsURL string
ChartJSURL string
index string
}
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Title: defaultTitle,
Refresh: defaultRefresh,
FontURL: defaultFontURL,
ChartJsURL: defaultChartJsURL,
ChartJSURL: defaultChartJSURL,
CustomHead: defaultCustomHead,
APIOnly: false,
Next: nil,
index: newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL,
defaultCustomHead}),
index: newIndex(viewBag{
defaultTitle,
defaultRefresh,
defaultFontURL,
defaultChartJSURL,
defaultCustomHead,
}),
}
func configDefault(config ...Config) Config {
// Users can change ConfigDefault.Title/Refresh which then
// become incompatible with ConfigDefault.index
if ConfigDefault.Title != defaultTitle || ConfigDefault.Refresh != defaultRefresh ||
ConfigDefault.FontURL != defaultFontURL || ConfigDefault.ChartJsURL != defaultChartJsURL ||
if ConfigDefault.Title != defaultTitle ||
ConfigDefault.Refresh != defaultRefresh ||
ConfigDefault.FontURL != defaultFontURL ||
ConfigDefault.ChartJSURL != defaultChartJSURL ||
ConfigDefault.CustomHead != defaultCustomHead {
if ConfigDefault.Refresh < minRefresh {
ConfigDefault.Refresh = minRefresh
}
// update default index with new default title/refresh
ConfigDefault.index = newIndex(viewBag{ConfigDefault.Title,
ConfigDefault.Refresh, ConfigDefault.FontURL, ConfigDefault.ChartJsURL, ConfigDefault.CustomHead})
ConfigDefault.index = newIndex(viewBag{
ConfigDefault.Title,
ConfigDefault.Refresh,
ConfigDefault.FontURL,
ConfigDefault.ChartJSURL,
ConfigDefault.CustomHead,
})
}
// Return default config if nothing provided
@ -93,8 +105,8 @@ func configDefault(config ...Config) Config {
cfg.FontURL = defaultFontURL
}
if cfg.ChartJsURL == "" {
cfg.ChartJsURL = defaultChartJsURL
if cfg.ChartJSURL == "" {
cfg.ChartJSURL = defaultChartJSURL
}
if cfg.Refresh < minRefresh {
cfg.Refresh = minRefresh
@ -112,8 +124,8 @@ func configDefault(config ...Config) Config {
cfg.index = newIndex(viewBag{
title: cfg.Title,
refresh: cfg.Refresh,
fontUrl: cfg.FontURL,
chartJsUrl: cfg.ChartJsURL,
fontURL: cfg.FontURL,
chartJSURL: cfg.ChartJSURL,
customHead: cfg.CustomHead,
})

View File

@ -18,11 +18,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set title", func(t *testing.T) {
@ -35,11 +35,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, title, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set refresh less than default", func(t *testing.T) {
@ -51,11 +51,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, minRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set refresh", func(t *testing.T) {
@ -68,45 +68,45 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, refresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set font url", func(t *testing.T) {
t.Parallel()
fontUrl := "https://example.com"
fontURL := "https://example.com"
cfg := configDefault(Config{
FontURL: fontUrl,
FontURL: fontURL,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, fontUrl, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, fontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontUrl, defaultChartJsURL, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set chart js url", func(t *testing.T) {
t.Parallel()
chartUrl := "http://example.com"
chartURL := "http://example.com"
cfg := configDefault(Config{
ChartJsURL: chartUrl,
ChartJSURL: chartURL,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, chartUrl, cfg.ChartJsURL)
utils.AssertEqual(t, chartURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartUrl, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartURL, defaultCustomHead}), cfg.index)
})
t.Run("set custom head", func(t *testing.T) {
@ -119,11 +119,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, head, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, head}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, head}), cfg.index)
})
t.Run("set api only", func(t *testing.T) {
@ -135,11 +135,11 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, true, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set next", func(t *testing.T) {
@ -154,10 +154,10 @@ func Test_Config_Default(t *testing.T) {
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, f(nil), cfg.Next(nil))
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
}

View File

@ -9,23 +9,22 @@ import (
type viewBag struct {
title string
refresh time.Duration
fontUrl string
chartJsUrl string
fontURL string
chartJSURL string
customHead string
}
// returns index with new title/refresh
func newIndex(dat viewBag) string {
timeout := dat.refresh.Milliseconds() - timeoutDiff
if timeout < timeoutDiff {
timeout = timeoutDiff
}
ts := strconv.FormatInt(timeout, 10)
replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts,
"$FONT_URL", dat.fontUrl, "$CHART_JS_URL", dat.chartJsUrl, "$CUSTOM_HEAD", dat.customHead,
"$FONT_URL", dat.fontURL, "$CHART_JS_URL", dat.chartJSURL, "$CUSTOM_HEAD", dat.customHead,
)
return replacer.Replace(indexHtml)
return replacer.Replace(indexHTML)
}
const (
@ -35,11 +34,11 @@ const (
timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff
minRefresh = timeoutDiff * time.Millisecond
defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap`
defaultChartJsURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
defaultChartJSURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
defaultCustomHead = ``
// parametrized by $TITLE and $TIMEOUT
indexHtml = `<!DOCTYPE html>
indexHTML = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">

View File

@ -33,18 +33,20 @@ type statsOS struct {
Conns int `json:"conns"`
}
//nolint:gochecknoglobals // TODO: Do not use a global var here
var (
monitPidCpu atomic.Value
monitPidRam atomic.Value
monitPidConns atomic.Value
monitPIDCPU atomic.Value
monitPIDRAM atomic.Value
monitPIDConns atomic.Value
monitOsCpu atomic.Value
monitOsRam atomic.Value
monitOsTotalRam atomic.Value
monitOsLoadAvg atomic.Value
monitOsConns atomic.Value
monitOSCPU atomic.Value
monitOSRAM atomic.Value
monitOSTotalRAM atomic.Value
monitOSLoadAvg atomic.Value
monitOSConns atomic.Value
)
//nolint:gochecknoglobals // TODO: Do not use a global var here
var (
mutex sync.RWMutex
once sync.Once
@ -58,7 +60,7 @@ func New(config ...Config) fiber.Handler {
// Start routine to update statistics
once.Do(func() {
p, _ := process.NewProcess(int32(os.Getpid()))
p, _ := process.NewProcess(int32(os.Getpid())) //nolint:errcheck // TODO: Handle error
updateStatistics(p)
@ -72,6 +74,7 @@ func New(config ...Config) fiber.Handler {
})
// Return new handler
//nolint:errcheck // Ignore the type-assertion errors
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
@ -83,15 +86,15 @@ func New(config ...Config) fiber.Handler {
}
if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly {
mutex.Lock()
data.PID.CPU = monitPidCpu.Load().(float64)
data.PID.RAM = monitPidRam.Load().(uint64)
data.PID.Conns = monitPidConns.Load().(int)
data.PID.CPU, _ = monitPIDCPU.Load().(float64)
data.PID.RAM, _ = monitPIDRAM.Load().(uint64)
data.PID.Conns, _ = monitPIDConns.Load().(int)
data.OS.CPU = monitOsCpu.Load().(float64)
data.OS.RAM = monitOsRam.Load().(uint64)
data.OS.TotalRAM = monitOsTotalRam.Load().(uint64)
data.OS.LoadAvg = monitOsLoadAvg.Load().(float64)
data.OS.Conns = monitOsConns.Load().(int)
data.OS.CPU, _ = monitOSCPU.Load().(float64)
data.OS.RAM, _ = monitOSRAM.Load().(uint64)
data.OS.TotalRAM, _ = monitOSTotalRAM.Load().(uint64)
data.OS.LoadAvg, _ = monitOSLoadAvg.Load().(float64)
data.OS.Conns, _ = monitOSConns.Load().(int)
mutex.Unlock()
return c.Status(fiber.StatusOK).JSON(data)
}
@ -101,29 +104,35 @@ func New(config ...Config) fiber.Handler {
}
func updateStatistics(p *process.Process) {
pidCpu, _ := p.CPUPercent()
monitPidCpu.Store(pidCpu / 10)
if osCpu, _ := cpu.Percent(0, false); len(osCpu) > 0 {
monitOsCpu.Store(osCpu[0])
pidCPU, err := p.CPUPercent()
if err != nil {
monitPIDCPU.Store(pidCPU / 10) //nolint:gomnd // TODO: Explain why we divide by 10 here
}
if pidMem, _ := p.MemoryInfo(); pidMem != nil {
monitPidRam.Store(pidMem.RSS)
if osCPU, err := cpu.Percent(0, false); err != nil && len(osCPU) > 0 {
monitOSCPU.Store(osCPU[0])
}
if osMem, _ := mem.VirtualMemory(); osMem != nil {
monitOsRam.Store(osMem.Used)
monitOsTotalRam.Store(osMem.Total)
if pidRAM, err := p.MemoryInfo(); err != nil && pidRAM != nil {
monitPIDRAM.Store(pidRAM.RSS)
}
if loadAvg, _ := load.Avg(); loadAvg != nil {
monitOsLoadAvg.Store(loadAvg.Load1)
if osRAM, err := mem.VirtualMemory(); err != nil && osRAM != nil {
monitOSRAM.Store(osRAM.Used)
monitOSTotalRAM.Store(osRAM.Total)
}
pidConns, _ := net.ConnectionsPid("tcp", p.Pid)
monitPidConns.Store(len(pidConns))
if loadAvg, err := load.Avg(); err != nil && loadAvg != nil {
monitOSLoadAvg.Store(loadAvg.Load1)
}
osConns, _ := net.Connections("tcp")
monitOsConns.Store(len(osConns))
pidConns, err := net.ConnectionsPid("tcp", p.Pid)
if err != nil {
monitPIDConns.Store(len(pidConns))
}
osConns, err := net.Connections("tcp")
if err != nil {
monitOSConns.Store(len(osConns))
}
}

View File

@ -10,6 +10,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -61,6 +62,7 @@ func Test_Monitor_Html(t *testing.T) {
conf.Refresh.Milliseconds()-timeoutDiff)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
}
func Test_Monitor_Html_CustomCodes(t *testing.T) {
t.Parallel()
@ -82,8 +84,10 @@ func Test_Monitor_Html_CustomCodes(t *testing.T) {
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
// custom config
conf := Config{Title: "New " + defaultTitle, Refresh: defaultRefresh + time.Second,
ChartJsURL: "https://cdnjs.com/libraries/Chart.js",
conf := Config{
Title: "New " + defaultTitle,
Refresh: defaultRefresh + time.Second,
ChartJSURL: "https://cdnjs.com/libraries/Chart.js",
FontURL: "/public/my-font.css",
CustomHead: `<style>body{background:#fff}</style>`,
}
@ -136,7 +140,7 @@ func Benchmark_Monitor(b *testing.B) {
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)

View File

@ -1,6 +1,8 @@
package pprof
import "github.com/gofiber/fiber/v2"
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
@ -17,6 +19,7 @@ type Config struct {
Prefix string
}
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
}

View File

@ -5,22 +5,8 @@ import (
"strings"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp/fasthttpadaptor"
)
// Set pprof adaptors
var (
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
pprofCmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
pprofProfile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
pprofSymbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
pprofTrace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
pprofAllocs = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP)
pprofBlock = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP)
pprofGoroutine = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP)
pprofHeap = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP)
pprofMutex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP)
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
"github.com/valyala/fasthttp/fasthttpadaptor"
)
// New creates a new middleware handler
@ -28,6 +14,21 @@ func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Set pprof adaptors
var (
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
pprofCmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
pprofProfile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
pprofSymbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
pprofTrace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
pprofAllocs = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP)
pprofBlock = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP)
pprofGoroutine = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP)
pprofHeap = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP)
pprofMutex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP)
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true

View File

@ -5,6 +5,7 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
@ -48,7 +49,7 @@ type Config struct {
WriteBufferSize int
// tls config for the http client.
TlsConfig *tls.Config
TlsConfig *tls.Config //nolint:stylecheck,revive // TODO: Rename to "TLSConfig" in v3
// Client is custom client when client config is complex.
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
@ -57,6 +58,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
ModifyRequest: nil,

View File

@ -3,19 +3,20 @@ package proxy
import (
"bytes"
"crypto/tls"
"fmt"
"log"
"net/url"
"strings"
"sync"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// New is deprecated
func New(config Config) fiber.Handler {
fmt.Println("proxy.New is deprecated, please use proxy.Balancer instead")
log.Printf("proxy.New is deprecated, please use proxy.Balancer instead\n")
return Balancer(config)
}
@ -25,7 +26,7 @@ func Balancer(config Config) fiber.Handler {
cfg := configDefault(config)
// Load balanced client
var lbc = &fasthttp.LBClient{}
lbc := &fasthttp.LBClient{}
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set.
if config.Client == nil {
@ -61,7 +62,7 @@ func Balancer(config Config) fiber.Handler {
}
// Return new handler
return func(c *fiber.Ctx) (err error) {
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
@ -76,7 +77,7 @@ func Balancer(config Config) fiber.Handler {
// Modify request
if cfg.ModifyRequest != nil {
if err = cfg.ModifyRequest(c); err != nil {
if err := cfg.ModifyRequest(c); err != nil {
return err
}
}
@ -84,7 +85,7 @@ func Balancer(config Config) fiber.Handler {
req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
// Forward request
if err = lbc.Do(req, res); err != nil {
if err := lbc.Do(req, res); err != nil {
return err
}
@ -93,7 +94,7 @@ func Balancer(config Config) fiber.Handler {
// Modify response
if cfg.ModifyResponse != nil {
if err = cfg.ModifyResponse(c); err != nil {
if err := cfg.ModifyResponse(c); err != nil {
return err
}
}
@ -103,16 +104,20 @@ func Balancer(config Config) fiber.Handler {
}
}
//nolint:gochecknoglobals // TODO: Do not use a global var here
var client = &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}
//nolint:gochecknoglobals // TODO: Do not use a global var here
var lock sync.RWMutex
// WithTlsConfig update http client with a user specified tls.config
// This function should be called before Do and Forward.
// Deprecated: use WithClient instead.
//
//nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3
func WithTlsConfig(tlsConfig *tls.Config) {
client.TLSConfig = tlsConfig
}

View File

@ -4,7 +4,6 @@ import (
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
@ -13,10 +12,11 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/tlstest"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) {
func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, string) {
t.Helper()
target := fiber.New(fiber.Config{DisableStartupMessage: true})
@ -60,7 +60,7 @@ func Test_Proxy_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -69,11 +69,11 @@ func Test_Proxy_Next(t *testing.T) {
func Test_Proxy(t *testing.T) {
t.Parallel()
target, addr := createProxyTestServer(
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
)
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
@ -81,7 +81,7 @@ func Test_Proxy(t *testing.T) {
app.Use(Balancer(Config{Servers: []string{addr}}))
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Host = addr
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
@ -107,7 +107,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
})
addr := ln.Addr().String()
clientTLSConf := &tls.Config{InsecureSkipVerify: true}
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification in Balancer
app.Use(Balancer(Config{
@ -128,9 +128,9 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
t.Parallel()
_, targetAddr := createProxyTestServer(func(c *fiber.Ctx) error {
_, targetAddr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("hello from target")
}, t)
})
proxyServerTLSConf, _, err := tlstest.GetTLSConfigs()
utils.AssertEqual(t, nil, err)
@ -164,13 +164,13 @@ func Test_Proxy_Forward(t *testing.T) {
app := fiber.New()
_, addr := createProxyTestServer(
func(c *fiber.Ctx) error { return c.SendString("forwarded") }, t,
)
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("forwarded")
})
app.Use(Forward("http://" + addr))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -198,7 +198,7 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
})
addr := ln.Addr().String()
clientTLSConf := &tls.Config{InsecureSkipVerify: true}
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification
WithTlsConfig(clientTLSConf)
@ -217,9 +217,9 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
func Test_Proxy_Modify_Response(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.Status(500).SendString("not modified")
}, t)
})
app := fiber.New()
app.Use(Balancer(Config{
@ -230,7 +230,7 @@ func Test_Proxy_Modify_Response(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -243,10 +243,10 @@ func Test_Proxy_Modify_Response(t *testing.T) {
func Test_Proxy_Modify_Request(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
b := c.Request().Body()
return c.SendString(string(b))
}, t)
})
app := fiber.New()
app.Use(Balancer(Config{
@ -257,7 +257,7 @@ func Test_Proxy_Modify_Request(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -270,10 +270,10 @@ func Test_Proxy_Modify_Request(t *testing.T) {
func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(2 * time.Second)
return c.SendString("fiber is awesome")
}, t)
})
app := fiber.New()
app.Use(Balancer(Config{
@ -281,7 +281,7 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
Timeout: 3 * time.Second,
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 5000)
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 5000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
@ -294,10 +294,10 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
func Test_Proxy_With_Timeout(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(1 * time.Second)
return c.SendString("fiber is awesome")
}, t)
})
app := fiber.New()
app.Use(Balancer(Config{
@ -305,7 +305,7 @@ func Test_Proxy_With_Timeout(t *testing.T) {
Timeout: 100 * time.Millisecond,
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 2000)
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
@ -318,16 +318,16 @@ func Test_Proxy_With_Timeout(t *testing.T) {
func Test_Proxy_Buffer_Size_Response(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
long := strings.Join(make([]string, 5000), "-")
c.Set("Very-Long-Header", long)
return c.SendString("ok")
}, t)
})
app := fiber.New()
app.Use(Balancer(Config{Servers: []string{addr}}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
@ -337,7 +337,7 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) {
ReadBufferSize: 1024 * 8,
}))
resp, err = app.Test(httptest.NewRequest("GET", "/", nil))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
@ -357,9 +357,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
utils.AssertEqual(t, originalURL, c.OriginalURL())
return c.SendString("ok")
})
_, err1 := app.Test(httptest.NewRequest("GET", "/test", nil))
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
// This test requires multiple requests due to zero allocation used in fiber
_, err2 := app.Test(httptest.NewRequest("GET", "/test", nil))
_, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, nil, err2)
@ -369,9 +369,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("hello world")
}, t)
})
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/*", func(c *fiber.Ctx) error {
@ -386,7 +386,7 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
return nil
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/http://"+addr, nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, nil))
utils.AssertEqual(t, nil, err)
s, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
@ -431,9 +431,8 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Dial: func(addr string) (net.Conn, error) {
return fasthttp.Dial(addr)
},
Dial: fasthttp.Dial,
}))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
@ -447,11 +446,11 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
t.Parallel()
target, addr := createProxyTestServer(
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
)
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
@ -468,7 +467,7 @@ func Test_ProxyBalancer_Custom_Client(t *testing.T) {
Timeout: time.Second,
}}))
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Host = addr
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)

View File

@ -1,4 +1,4 @@
package recover
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"github.com/gofiber/fiber/v2"
@ -23,6 +23,8 @@ type Config struct {
}
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
EnableStackTrace: false,

View File

@ -1,4 +1,4 @@
package recover
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"fmt"
@ -9,7 +9,7 @@ import (
)
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack()))
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
}
// New creates a new middleware handler
@ -18,7 +18,7 @@ func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)
// Return new handler
return func(c *fiber.Ctx) (err error) {
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()

View File

@ -1,4 +1,4 @@
package recover
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"net/http/httptest"
@ -24,7 +24,7 @@ func Test_Recover(t *testing.T) {
panic("Hi, I'm an error!")
})
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}
@ -39,7 +39,7 @@ func Test_Recover_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
@ -55,7 +55,7 @@ func Test_Recover_EnableStackTrace(t *testing.T) {
panic("Hi, I'm an error!")
})
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
}

View File

@ -33,6 +33,8 @@ type Config struct {
// It uses a fast UUID generator which will expose the number of
// requests made to the server. To conceal this value for better
// privacy, use the "utils.UUIDv4" generator.
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Next: nil,
Header: fiber.HeaderXRequestID,

View File

@ -19,14 +19,14 @@ func Test_RequestID(t *testing.T) {
return c.SendString("Hello, World 👋!")
})
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
reqid := resp.Header.Get(fiber.HeaderXRequestID)
utils.AssertEqual(t, 36, len(reqid))
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add(fiber.HeaderXRequestID, reqid)
resp, err = app.Test(req)
@ -45,7 +45,7 @@ func Test_RequestID_Next(t *testing.T) {
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "")
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
@ -54,13 +54,13 @@ func Test_RequestID_Next(t *testing.T) {
// go test -run Test_RequestID_Locals
func Test_RequestID_Locals(t *testing.T) {
t.Parallel()
reqId := "ThisIsARequestId"
reqID := "ThisIsARequestId"
ctxKey := "ThisIsAContextKey"
app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return reqId
return reqID
},
ContextKey: ctxKey,
}))
@ -68,11 +68,11 @@ func Test_RequestID_Locals(t *testing.T) {
var ctxVal string
app.Use(func(c *fiber.Ctx) error {
ctxVal = c.Locals(ctxKey).(string)
ctxVal = c.Locals(ctxKey).(string) //nolint:forcetypeassert,errcheck // We always store a string in here
return c.Next()
})
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, reqId, ctxVal)
utils.AssertEqual(t, reqID, ctxVal)
}

View File

@ -32,7 +32,7 @@ func (s *Session) Save() error
func (s *Session) Fresh() bool
func (s *Session) ID() string
func (s *Session) Keys() []string
func (s *Session) SetExpiry(time.Duration)
func (s *Session) SetExpiry(time.Duration)
```
**⚠ _Storing `interface{}` values are limited to built-ins Go types_**
@ -148,7 +148,7 @@ type Config struct {
// Optional. Default value utils.UUID
KeyGenerator func() string
// Deprecated, please use KeyLookup
// Deprecated: Please use KeyLookup
CookieName string
// Source defines where to obtain the session id

View File

@ -2,6 +2,7 @@ package session
import (
"fmt"
"log"
"strings"
"time"
@ -49,7 +50,7 @@ type Config struct {
// Optional. Default value utils.UUIDv4
KeyGenerator func() string
// Deprecated, please use KeyLookup
// Deprecated: Please use KeyLookup
CookieName string
// Source defines where to obtain the session id
@ -68,8 +69,10 @@ const (
)
// ConfigDefault is the default config
//
//nolint:gochecknoglobals // Using a global var is fine here
var ConfigDefault = Config{
Expiration: 24 * time.Hour,
Expiration: 24 * time.Hour, //nolint:gomnd // No magic number, just the default config
KeyLookup: "cookie:session_id",
KeyGenerator: utils.UUIDv4,
source: "cookie",
@ -91,7 +94,7 @@ func configDefault(config ...Config) Config {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.CookieName != "" {
fmt.Println("[session] CookieName is deprecated, please use KeyLookup")
log.Printf("[session] CookieName is deprecated, please use KeyLookup\n")
cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName)
}
if cfg.KeyLookup == "" {
@ -102,7 +105,8 @@ func configDefault(config ...Config) Config {
}
selectors := strings.Split(cfg.KeyLookup, ":")
if len(selectors) != 2 {
const numSelectors = 2
if len(selectors) != numSelectors {
panic("[session] KeyLookup must in the form of <source>:<name>")
}
switch Source(selectors[0]) {

View File

@ -13,6 +13,7 @@ type data struct {
Data map[string]interface{}
}
//nolint:gochecknoglobals // TODO: Do not use a global var here
var dataPool = sync.Pool{
New: func() interface{} {
d := new(data)
@ -22,7 +23,7 @@ var dataPool = sync.Pool{
}
func acquireData() *data {
return dataPool.Get().(*data)
return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool
}
func (d *data) Reset() {

View File

@ -3,11 +3,13 @@ package session
import (
"bytes"
"encoding/gob"
"fmt"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -21,6 +23,7 @@ type Session struct {
exp time.Duration // expiration of this session
}
//nolint:gochecknoglobals // TODO: Do not use a global var here
var sessionPool = sync.Pool{
New: func() interface{} {
return new(Session)
@ -28,7 +31,7 @@ var sessionPool = sync.Pool{
}
func acquireSession() *Session {
s := sessionPool.Get().(*Session)
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
if s.data == nil {
s.data = acquireData()
}
@ -153,7 +156,7 @@ func (s *Session) Save() error {
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
if err != nil {
return err
return fmt.Errorf("failed to encode data: %w", err)
}
// copy the data in buffer

View File

@ -7,6 +7,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -94,6 +95,8 @@ func Test_Session(t *testing.T) {
}
// go test -run Test_Session_Types
//
//nolint:forcetypeassert // TODO: Do not force-type assert
func Test_Session_Types(t *testing.T) {
t.Parallel()
@ -127,25 +130,27 @@ func Test_Session_Types(t *testing.T) {
Name: "John",
}
// set value
var vbool = true
var vstring = "str"
var vint = 13
var vint8 int8 = 13
var vint16 int16 = 13
var vint32 int32 = 13
var vint64 int64 = 13
var vuint uint = 13
var vuint8 uint8 = 13
var vuint16 uint16 = 13
var vuint32 uint32 = 13
var vuint64 uint64 = 13
var vuintptr uintptr = 13
var vbyte byte = 'k'
var vrune rune = 'k'
var vfloat32 float32 = 13
var vfloat64 float64 = 13
var vcomplex64 complex64 = 13
var vcomplex128 complex128 = 13
var (
vbool = true
vstring = "str"
vint = 13
vint8 int8 = 13
vint16 int16 = 13
vint32 int32 = 13
vint64 int64 = 13
vuint uint = 13
vuint8 uint8 = 13
vuint16 uint16 = 13
vuint32 uint32 = 13
vuint64 uint64 = 13
vuintptr uintptr = 13
vbyte byte = 'k'
vrune = 'k'
vfloat32 float32 = 13
vfloat64 float64 = 13
vcomplex64 complex64 = 13
vcomplex128 complex128 = 13
)
sess.Set("vuser", vuser)
sess.Set("vbool", vbool)
sess.Set("vstring", vstring)
@ -212,7 +217,8 @@ func Test_Session_Store_Reset(t *testing.T) {
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// make sure its new
utils.AssertEqual(t, true, sess.Fresh())
// set value & save
@ -224,7 +230,8 @@ func Test_Session_Store_Reset(t *testing.T) {
utils.AssertEqual(t, nil, store.Reset())
// make sure the session is recreated
sess, _ = store.Get(ctx)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, nil, sess.Get("hello"))
}
@ -242,12 +249,13 @@ func Test_Session_Save(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value
sess.Set("name", "john")
// save session
err := sess.Save()
err = sess.Save()
utils.AssertEqual(t, nil, err)
})
@ -262,12 +270,13 @@ func Test_Session_Save(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value
sess.Set("name", "john")
// save session
err := sess.Save()
err = sess.Save()
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
@ -287,7 +296,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value
sess.Set("name", "john")
@ -295,18 +305,20 @@ func Test_Session_Save_Expiration(t *testing.T) {
sess.SetExpiry(time.Second * 5)
// save session
err := sess.Save()
err = sess.Save()
utils.AssertEqual(t, nil, err)
// here you need to get the old session yet
sess, _ = store.Get(ctx)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "john", sess.Get("name"))
// just to make sure the session has been expired
time.Sleep(time.Second * 5)
// here you should get a new session
sess, _ = store.Get(ctx)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, sess.Get("name"))
})
}
@ -325,7 +337,8 @@ func Test_Session_Reset(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("name", "fenny")
utils.AssertEqual(t, nil, sess.Destroy())
@ -345,14 +358,16 @@ func Test_Session_Reset(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value & save
sess.Set("name", "fenny")
utils.AssertEqual(t, nil, sess.Save())
sess, _ = store.Get(ctx)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
err := sess.Destroy()
err = sess.Destroy()
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
@ -383,7 +398,8 @@ func Test_Session_Cookie(t *testing.T) {
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, sess.Save())
// cookie should be set on Save ( even if empty data )
@ -401,12 +417,14 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("id", "1")
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, nil, sess.Save())
sess, _ = store.Get(ctx)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("name", "john")
utils.AssertEqual(t, true, sess.Fresh())
@ -497,7 +515,7 @@ func Benchmark_Session(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, _ := store.Get(c)
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
err = sess.Save()
}
@ -512,7 +530,7 @@ func Benchmark_Session(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, _ := store.Get(c)
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
err = sess.Save()
}

View File

@ -2,11 +2,13 @@ package session
import (
"encoding/gob"
"fmt"
"sync"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -14,6 +16,7 @@ type Store struct {
Config
}
//nolint:gochecknoglobals // TODO: Do not use a global var here
var mux sync.Mutex
func New(config ...Config) *Store {
@ -31,7 +34,7 @@ func New(config ...Config) *Store {
// RegisterType will allow you to encode/decode custom types
// into any Storage provider
func (s *Store) RegisterType(i interface{}) {
func (*Store) RegisterType(i interface{}) {
gob.Register(i)
}
@ -70,11 +73,11 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
if raw != nil && err == nil {
mux.Lock()
defer mux.Unlock()
_, _ = sess.byteBuffer.Write(raw)
_, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(sess.byteBuffer)
err := encCache.Decode(&sess.data.Data)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
} else if err != nil {
return nil, err

View File

@ -6,6 +6,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)

View File

@ -1,6 +1,8 @@
package skip
import "github.com/gofiber/fiber/v2"
import (
"github.com/gofiber/fiber/v2"
)
// New creates a middleware handler which skips the wrapped handler
// if the exclude predicate returns true.

View File

@ -17,7 +17,7 @@ func Test_Skip(t *testing.T) {
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return true }))
app.Get("/", helloWorldHandler)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
@ -30,7 +30,7 @@ func Test_SkipFalse(t *testing.T) {
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return false }))
app.Get("/", helloWorldHandler)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}
@ -43,7 +43,7 @@ func Test_SkipNilFunc(t *testing.T) {
app.Use(skip.New(errTeapotHandler, nil))
app.Get("/", helloWorldHandler)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}

View File

@ -18,7 +18,8 @@ func Test_Timeout(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
utils.AssertEqual(t, nil, err)
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
}
@ -26,12 +27,12 @@ func Test_Timeout(t *testing.T) {
}, 100*time.Millisecond)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
@ -49,7 +50,8 @@ func Test_TimeoutWithCustomError(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
utils.AssertEqual(t, nil, err)
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err)
}
@ -57,12 +59,12 @@ func Test_TimeoutWithCustomError(t *testing.T) {
}, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}

Some files were not shown because too many files have changed in this diff Show More