mirror of https://github.com/gofiber/fiber.git
🚀 Feature: Add and apply more stricter golangci-lint linting rules (#2286)
* golangci-lint: add and apply more stricter linting rules * github: drop security workflow now that we use gosec linter inside golangci-lint * github: use official golangci-lint CI linter * Add editorconfig and gitattributes filepull/2313/head
parent
7327a17951
commit
167a8b5e94
|
@ -0,0 +1,8 @@
|
|||
; This file is for unifying the coding style for different editors and IDEs.
|
||||
; More information at http://editorconfig.org
|
||||
; This style originates from https://github.com/fewagency/best-practices
|
||||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
|
@ -0,0 +1,12 @@
|
|||
# Handle line endings automatically for files detected as text
|
||||
# and leave all files detected as binary untouched.
|
||||
* text=auto eol=lf
|
||||
|
||||
# Force batch scripts to always use CRLF line endings so that if a repo is accessed
|
||||
# in Windows via a file share from Linux, the scripts will work.
|
||||
*.{cmd,[cC][mM][dD]} text eol=crlf
|
||||
*.{bat,[bB][aA][tT]} text eol=crlf
|
||||
|
||||
# Force bash scripts to always use LF line endings so that if a repo is accessed
|
||||
# in Unix via a file share from Windows, the scripts will work.
|
||||
*.sh text eol=lf
|
|
@ -1,17 +1,28 @@
|
|||
# Adapted from https://github.com/golangci/golangci-lint-action/blob/b56f6f529003f1c81d4d759be6bd5f10bf9a0fa0/README.md#how-to-use
|
||||
|
||||
name: golangci-lint
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
name: Linter
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
Golint:
|
||||
golangci:
|
||||
name: lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Fetch Repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Run Golint
|
||||
uses: reviewdog/action-golangci-lint@v2
|
||||
- uses: actions/setup-go@v3
|
||||
with:
|
||||
golangci_lint_flags: "--tests=false"
|
||||
# NOTE: Keep this in sync with the version from go.mod
|
||||
go-version: 1.19
|
||||
- uses: actions/checkout@v3
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
# NOTE: Keep this in sync with the version from .golangci.yml
|
||||
version: v1.50.1
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
name: Security
|
||||
jobs:
|
||||
Gosec:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Fetch Repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Run Gosec
|
||||
uses: securego/gosec@master
|
||||
with:
|
||||
args: -exclude-dir=internal/*/ ./...
|
|
@ -0,0 +1,258 @@
|
|||
# Created based on v1.50.1
|
||||
# NOTE: Keep this in sync with the version in .github/workflows/linter.yml
|
||||
|
||||
run:
|
||||
modules-download-mode: readonly
|
||||
skip-dirs-use-default: false
|
||||
skip-dirs:
|
||||
- internal # TODO: Also apply proper linting for internal dir
|
||||
|
||||
output:
|
||||
sort-results: true
|
||||
|
||||
linters-settings:
|
||||
# TODO: Eventually enable these checks
|
||||
# depguard:
|
||||
# include-go-root: true
|
||||
# packages:
|
||||
# - flag
|
||||
# - io/ioutil
|
||||
# - reflect
|
||||
# - unsafe
|
||||
# packages-with-error-message:
|
||||
# - flag: '`flag` package is only allowed in main.go'
|
||||
# - io/ioutil: '`io/ioutil` package is deprecated, use the `io` and `os` package instead'
|
||||
# - reflect: '`reflect` package is dangerous to use'
|
||||
# - unsafe: '`unsafe` package is dangerous to use'
|
||||
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
check-blank: true
|
||||
disable-default-exclusions: true
|
||||
|
||||
errchkjson:
|
||||
report-no-exported: true
|
||||
|
||||
exhaustive:
|
||||
default-signifies-exhaustive: true
|
||||
|
||||
forbidigo:
|
||||
forbid:
|
||||
- ^(fmt\.Print(|f|ln)|print|println)$
|
||||
- 'http\.Default(Client|Transport)'
|
||||
# TODO: Eventually enable these patterns
|
||||
# - 'time\.Sleep'
|
||||
# - 'panic'
|
||||
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- prefix(github.com/gofiber/fiber)
|
||||
- default
|
||||
- blank
|
||||
- dot
|
||||
custom-order: true
|
||||
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- ifElseChain
|
||||
|
||||
gofumpt:
|
||||
module-path: github.com/gofiber/fiber
|
||||
extra-rules: true
|
||||
|
||||
gosec:
|
||||
config:
|
||||
global:
|
||||
audit: true
|
||||
|
||||
govet:
|
||||
check-shadowing: true
|
||||
enable-all: true
|
||||
disable:
|
||||
- shadow
|
||||
- fieldalignment
|
||||
|
||||
grouper:
|
||||
import-require-single-import: true
|
||||
import-require-grouping: true
|
||||
|
||||
misspell:
|
||||
locale: US
|
||||
|
||||
nolintlint:
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
|
||||
nonamedreturns:
|
||||
report-error-in-defer: true
|
||||
|
||||
predeclared:
|
||||
q: true
|
||||
|
||||
promlinter:
|
||||
strict: true
|
||||
|
||||
revive:
|
||||
enable-all-rules: true
|
||||
rules:
|
||||
# Provided by gomnd linter
|
||||
- name: add-constant
|
||||
disabled: true
|
||||
- name: argument-limit
|
||||
disabled: true
|
||||
# Provided by bidichk
|
||||
- name: banned-characters
|
||||
disabled: true
|
||||
- name: cognitive-complexity
|
||||
disabled: true
|
||||
- name: cyclomatic
|
||||
disabled: true
|
||||
- name: exported
|
||||
disabled: true
|
||||
- name: file-header
|
||||
disabled: true
|
||||
- name: function-result-limit
|
||||
disabled: true
|
||||
- name: function-length
|
||||
disabled: true
|
||||
- name: line-length-limit
|
||||
disabled: true
|
||||
- name: max-public-structs
|
||||
disabled: true
|
||||
- name: modifies-parameter
|
||||
disabled: true
|
||||
- name: nested-structs
|
||||
disabled: true
|
||||
- name: package-comments
|
||||
disabled: true
|
||||
|
||||
stylecheck:
|
||||
checks:
|
||||
- all
|
||||
- -ST1000
|
||||
- -ST1020
|
||||
- -ST1021
|
||||
- -ST1022
|
||||
|
||||
tagliatelle:
|
||||
case:
|
||||
rules:
|
||||
json: snake
|
||||
|
||||
#tenv:
|
||||
# all: true
|
||||
|
||||
#unparam:
|
||||
# check-exported: true
|
||||
|
||||
wrapcheck:
|
||||
ignorePackageGlobs:
|
||||
- github.com/gofiber/fiber/*
|
||||
- github.com/valyala/fasthttp
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
|
||||
linters:
|
||||
enable:
|
||||
- asasalint
|
||||
- asciicheck
|
||||
- bidichk
|
||||
- bodyclose
|
||||
- containedctx
|
||||
- contextcheck
|
||||
# - cyclop
|
||||
# - deadcode
|
||||
# - decorder
|
||||
- depguard
|
||||
- dogsled
|
||||
# - dupl
|
||||
# - dupword
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- errchkjson
|
||||
- errname
|
||||
- errorlint
|
||||
- execinquery
|
||||
- exhaustive
|
||||
# - exhaustivestruct
|
||||
# - exhaustruct
|
||||
- exportloopref
|
||||
- forbidigo
|
||||
- forcetypeassert
|
||||
# - funlen
|
||||
- gci
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
# - gocognit
|
||||
- goconst
|
||||
- gocritic
|
||||
# - gocyclo
|
||||
# - godot
|
||||
# - godox
|
||||
# - goerr113
|
||||
- gofmt
|
||||
- gofumpt
|
||||
# - goheader
|
||||
- goimports
|
||||
# - golint
|
||||
- gomnd
|
||||
- gomoddirectives
|
||||
# - gomodguard
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- gosimple
|
||||
- govet
|
||||
- grouper
|
||||
# - ifshort
|
||||
# - importas
|
||||
- ineffassign
|
||||
# - interfacebloat
|
||||
# - interfacer
|
||||
# - ireturn
|
||||
# - lll
|
||||
- loggercheck
|
||||
# - maintidx
|
||||
# - makezero
|
||||
# - maligned
|
||||
- misspell
|
||||
- nakedret
|
||||
# - nestif
|
||||
- nilerr
|
||||
- nilnil
|
||||
# - nlreturn
|
||||
- noctx
|
||||
- nolintlint
|
||||
- nonamedreturns
|
||||
# - nosnakecase
|
||||
- nosprintfhostport
|
||||
# - paralleltest # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
|
||||
# - prealloc
|
||||
- predeclared
|
||||
- promlinter
|
||||
- reassign
|
||||
- revive
|
||||
- rowserrcheck
|
||||
# - scopelint
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
# - structcheck
|
||||
- stylecheck
|
||||
- tagliatelle
|
||||
# - tenv # TODO: Enable once we drop support for Go 1.16
|
||||
# - testableexamples
|
||||
# - testpackage # TODO: Enable once https://github.com/gofiber/fiber/issues/2252 is implemented
|
||||
- thelper
|
||||
# - tparallel # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
|
||||
- typecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usestdlibvars
|
||||
# - varcheck
|
||||
# - varnamelen
|
||||
- wastedassign
|
||||
- whitespace
|
||||
- wrapcheck
|
||||
# - wsl
|
58
app.go
58
app.go
|
@ -14,6 +14,7 @@ import (
|
|||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
@ -24,6 +25,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -306,7 +308,7 @@ type Config struct {
|
|||
|
||||
// FEATURE: v2.3.x
|
||||
// The router executes the same handler by default if StrictRouting or CaseSensitive is disabled.
|
||||
// Enabling RedirectFixedPath will change this behaviour into a client redirect to the original route path.
|
||||
// Enabling RedirectFixedPath will change this behavior into a client redirect to the original route path.
|
||||
// Using the status code 301 for GET requests and 308 for all other request methods.
|
||||
//
|
||||
// Default: false
|
||||
|
@ -454,6 +456,8 @@ const (
|
|||
)
|
||||
|
||||
// HTTP methods enabled by default
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var DefaultMethods = []string{
|
||||
MethodGet,
|
||||
MethodHead,
|
||||
|
@ -467,7 +471,7 @@ var DefaultMethods = []string{
|
|||
}
|
||||
|
||||
// DefaultErrorHandler that process return errors from handlers
|
||||
var DefaultErrorHandler = func(c *Ctx, err error) error {
|
||||
func DefaultErrorHandler(c *Ctx, err error) error {
|
||||
code := StatusInternalServerError
|
||||
var e *Error
|
||||
if errors.As(err, &e) {
|
||||
|
@ -519,7 +523,7 @@ func New(config ...Config) *App {
|
|||
|
||||
if app.config.ETag {
|
||||
if !IsChild() {
|
||||
fmt.Println("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.")
|
||||
log.Printf("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.\n")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -587,7 +591,7 @@ func (app *App) handleTrustedProxy(ipAddress string) {
|
|||
if strings.Contains(ipAddress, "/") {
|
||||
_, ipNet, err := net.ParseCIDR(ipAddress)
|
||||
if err != nil {
|
||||
fmt.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err)
|
||||
log.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err)
|
||||
} else {
|
||||
app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet)
|
||||
}
|
||||
|
@ -822,7 +826,7 @@ func (app *App) Config() Config {
|
|||
}
|
||||
|
||||
// Handler returns the server handler.
|
||||
func (app *App) Handler() fasthttp.RequestHandler {
|
||||
func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
|
||||
// prepare the server for the start
|
||||
app.startupProcess()
|
||||
return app.handler
|
||||
|
@ -887,7 +891,7 @@ func (app *App) Hooks() *Hooks {
|
|||
|
||||
// Test is used for internal debugging by passing a *http.Request.
|
||||
// Timeout is optional and defaults to 1s, -1 will disable it completely.
|
||||
func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response, err error) {
|
||||
func (app *App) Test(req *http.Request, msTimeout ...int) (*http.Response, error) {
|
||||
// Set timeout
|
||||
timeout := 1000
|
||||
if len(msTimeout) > 0 {
|
||||
|
@ -902,15 +906,15 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
|
|||
// Dump raw http request
|
||||
dump, err := httputil.DumpRequest(req, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to dump request: %w", err)
|
||||
}
|
||||
|
||||
// Create test connection
|
||||
conn := new(testConn)
|
||||
|
||||
// Write raw http request
|
||||
if _, err = conn.r.Write(dump); err != nil {
|
||||
return nil, err
|
||||
if _, err := conn.r.Write(dump); err != nil {
|
||||
return nil, fmt.Errorf("failed to write: %w", err)
|
||||
}
|
||||
// prepare the server for the start
|
||||
app.startupProcess()
|
||||
|
@ -943,7 +947,7 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
|
|||
}
|
||||
|
||||
// Check for errors
|
||||
if err != nil && err != fasthttp.ErrGetOnly {
|
||||
if err != nil && !errors.Is(err, fasthttp.ErrGetOnly) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -951,12 +955,17 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
|
|||
buffer := bufio.NewReader(&conn.w)
|
||||
|
||||
// Convert raw http response to *http.Response
|
||||
return http.ReadResponse(buffer, req)
|
||||
res, err := http.ReadResponse(buffer, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
type disableLogger struct{}
|
||||
|
||||
func (dl *disableLogger) Printf(_ string, _ ...interface{}) {
|
||||
func (*disableLogger) Printf(_ string, _ ...interface{}) {
|
||||
// fmt.Println(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
|
@ -967,7 +976,7 @@ func (app *App) init() *App {
|
|||
// Only load templates if a view engine is specified
|
||||
if app.config.Views != nil {
|
||||
if err := app.config.Views.Load(); err != nil {
|
||||
fmt.Printf("views: %v\n", err)
|
||||
log.Printf("[Warning]: failed to load views: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1039,25 +1048,30 @@ func (app *App) ErrorHandler(ctx *Ctx, err error) error {
|
|||
// errors before calling the application's error handler method.
|
||||
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
|
||||
c := app.AcquireCtx(fctx)
|
||||
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
var errNetOP *net.OpError
|
||||
|
||||
switch {
|
||||
case errors.As(err, new(*fasthttp.ErrSmallBuffer)):
|
||||
err = ErrRequestHeaderFieldsTooLarge
|
||||
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
||||
case errors.As(err, &errNetOP) && errNetOP.Timeout():
|
||||
err = ErrRequestTimeout
|
||||
} else if err == fasthttp.ErrBodyTooLarge {
|
||||
case errors.Is(err, fasthttp.ErrBodyTooLarge):
|
||||
err = ErrRequestEntityTooLarge
|
||||
} else if err == fasthttp.ErrGetOnly {
|
||||
case errors.Is(err, fasthttp.ErrGetOnly):
|
||||
err = ErrMethodNotAllowed
|
||||
} else if strings.Contains(err.Error(), "timeout") {
|
||||
case strings.Contains(err.Error(), "timeout"):
|
||||
err = ErrRequestTimeout
|
||||
} else {
|
||||
default:
|
||||
err = NewError(StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
if catch := app.ErrorHandler(c, err); catch != nil {
|
||||
_ = c.SendStatus(StatusInternalServerError)
|
||||
log.Printf("serverErrorHandler: failed to call ErrorHandler: %v\n", catch)
|
||||
_ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here
|
||||
return
|
||||
}
|
||||
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
|
||||
// startupProcess Is the method which executes all the necessary processes just before the start of the server.
|
||||
|
|
65
app_test.go
65
app_test.go
|
@ -2,6 +2,7 @@
|
|||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package fiber
|
||||
|
||||
import (
|
||||
|
@ -23,15 +24,16 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
)
|
||||
|
||||
var testEmptyHandler = func(c *Ctx) error {
|
||||
func testEmptyHandler(_ *Ctx) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func testStatus200(t *testing.T, app *App, url string, method string) {
|
||||
func testStatus200(t *testing.T, app *App, url, method string) {
|
||||
t.Helper()
|
||||
|
||||
req := httptest.NewRequest(method, url, nil)
|
||||
|
@ -42,6 +44,8 @@ func testStatus200(t *testing.T, app *App, url string, method string) {
|
|||
}
|
||||
|
||||
func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBodyError string) {
|
||||
t.Helper()
|
||||
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 500, resp.StatusCode, "Status code")
|
||||
|
||||
|
@ -140,7 +144,6 @@ func Test_App_ServerErrorHandler_SmallReadBuffer(t *testing.T) {
|
|||
logHeaderSlice := make([]string, 5000)
|
||||
request.Header.Set("Very-Long-Header", strings.Join(logHeaderSlice, "-"))
|
||||
_, err := app.Test(request)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expect an error at app.Test(request)")
|
||||
}
|
||||
|
@ -470,7 +473,6 @@ func Test_App_Use_MultiplePrefix(t *testing.T) {
|
|||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "/test/doe", string(body))
|
||||
|
||||
}
|
||||
|
||||
func Test_App_Use_StrictRouting(t *testing.T) {
|
||||
|
@ -515,7 +517,7 @@ func Test_App_Add_Method_Test(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
methods := append(DefaultMethods, "JOHN")
|
||||
methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here
|
||||
app := New(Config{
|
||||
RequestMethods: methods,
|
||||
})
|
||||
|
@ -780,7 +782,7 @@ func Test_App_ShutdownWithTimeout(t *testing.T) {
|
|||
case <-timer.C:
|
||||
t.Fatal("idle connections not closed on shutdown")
|
||||
case err := <-shutdownErr:
|
||||
if err == nil || err != context.DeadlineExceeded {
|
||||
if err == nil || !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
|
||||
}
|
||||
}
|
||||
|
@ -852,7 +854,7 @@ func Test_App_Static_MaxAge(t *testing.T) {
|
|||
|
||||
app.Static("/", "./.github", Static{MaxAge: 100})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
|
||||
|
@ -866,19 +868,19 @@ func Test_App_Static_Custom_CacheControl(t *testing.T) {
|
|||
app := New()
|
||||
|
||||
app.Static("/", "./.github", Static{ModifyResponse: func(c *Ctx) error {
|
||||
if strings.Contains(string(c.GetRespHeader("Content-Type")), "text/html") {
|
||||
if strings.Contains(c.GetRespHeader("Content-Type"), "text/html") {
|
||||
c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
}
|
||||
return nil
|
||||
}})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, "no-cache, no-store, must-revalidate", resp.Header.Get(HeaderCacheControl), "CacheControl Control")
|
||||
|
||||
normal_resp, normal_err := app.Test(httptest.NewRequest("GET", "/config.yml", nil))
|
||||
utils.AssertEqual(t, nil, normal_err, "app.Test(req)")
|
||||
utils.AssertEqual(t, "", normal_resp.Header.Get(HeaderCacheControl), "CacheControl Control")
|
||||
respNormal, errNormal := app.Test(httptest.NewRequest(MethodGet, "/config.yml", nil))
|
||||
utils.AssertEqual(t, nil, errNormal, "app.Test(req)")
|
||||
utils.AssertEqual(t, "", respNormal.Header.Get(HeaderCacheControl), "CacheControl Control")
|
||||
}
|
||||
|
||||
// go test -run Test_App_Static_Download
|
||||
|
@ -890,7 +892,7 @@ func Test_App_Static_Download(t *testing.T) {
|
|||
|
||||
app.Static("/fiber.png", "./.github/testdata/fs/img/fiber.png", Static{Download: true})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/fiber.png", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/fiber.png", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
|
||||
|
@ -1067,7 +1069,7 @@ func Test_App_Static_Next(t *testing.T) {
|
|||
|
||||
t.Run("app.Static is skipped: invoking Get handler", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(MethodGet, "/", nil)
|
||||
req.Header.Set("X-Custom-Header", "skip")
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -1082,7 +1084,7 @@ func Test_App_Static_Next(t *testing.T) {
|
|||
|
||||
t.Run("app.Static is not skipped: serving index.html", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(MethodGet, "/", nil)
|
||||
req.Header.Set("X-Custom-Header", "don't skip")
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -1379,7 +1381,7 @@ func Test_Test_DumpError(t *testing.T) {
|
|||
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", errorReader(0)))
|
||||
utils.AssertEqual(t, true, resp == nil)
|
||||
utils.AssertEqual(t, "errorReader", err.Error())
|
||||
utils.AssertEqual(t, "failed to dump request: errorReader", err.Error())
|
||||
}
|
||||
|
||||
// go test -run Test_App_Handler
|
||||
|
@ -1393,7 +1395,7 @@ type invalidView struct{}
|
|||
|
||||
func (invalidView) Load() error { return errors.New("invalid view") }
|
||||
|
||||
func (i invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") }
|
||||
func (invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") }
|
||||
|
||||
// go test -run Test_App_Init_Error_View
|
||||
func Test_App_Init_Error_View(t *testing.T) {
|
||||
|
@ -1405,7 +1407,9 @@ func Test_App_Init_Error_View(t *testing.T) {
|
|||
utils.AssertEqual(t, "implement me", fmt.Sprintf("%v", err))
|
||||
}
|
||||
}()
|
||||
_ = app.config.Views.Render(nil, "", nil)
|
||||
|
||||
err := app.config.Views.Render(nil, "", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
|
||||
// go test -run Test_App_Stack
|
||||
|
@ -1535,11 +1539,12 @@ func Test_App_SmallReadBuffer(t *testing.T) {
|
|||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
resp, err := http.Get("http://127.0.0.1:4006/small-read-buffer")
|
||||
if resp != nil {
|
||||
utils.AssertEqual(t, 431, resp.StatusCode)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(context.Background(), MethodGet, "http://127.0.0.1:4006/small-read-buffer", http.NoBody)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
var client http.Client
|
||||
resp, err := client.Do(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 431, resp.StatusCode)
|
||||
utils.AssertEqual(t, nil, app.Shutdown())
|
||||
}()
|
||||
|
||||
|
@ -1572,13 +1577,13 @@ func Test_App_New_Test_Parallel(t *testing.T) {
|
|||
t.Run("Test_App_New_Test_Parallel_1", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := New(Config{Immutable: true})
|
||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
_, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
})
|
||||
t.Run("Test_App_New_Test_Parallel_2", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := New(Config{Immutable: true})
|
||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
_, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
})
|
||||
}
|
||||
|
@ -1591,7 +1596,7 @@ func Test_App_ReadBodyStream(t *testing.T) {
|
|||
return c.SendString(fmt.Sprintf("%v %s", c.Request().IsBodyStream(), c.Body()))
|
||||
})
|
||||
testString := "this is a test"
|
||||
resp, err := app.Test(httptest.NewRequest("POST", "/", bytes.NewBufferString(testString)))
|
||||
resp, err := app.Test(httptest.NewRequest(MethodPost, "/", bytes.NewBufferString(testString)))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err, "io.ReadAll(resp.Body)")
|
||||
|
@ -1615,12 +1620,12 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
|
|||
}
|
||||
file, err := mpf.File["test"][0].Open()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
buffer := make([]byte, len(testString))
|
||||
n, err := file.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read: %w", err)
|
||||
}
|
||||
if n != len(testString) {
|
||||
return fmt.Errorf("bad read length")
|
||||
|
@ -1636,7 +1641,7 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
|
|||
utils.AssertEqual(t, len(testString), n, "writer n")
|
||||
utils.AssertEqual(t, nil, w.Close(), "w.Close()")
|
||||
|
||||
req := httptest.NewRequest("POST", "/", b)
|
||||
req := httptest.NewRequest(MethodPost, "/", b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
|
@ -1659,7 +1664,7 @@ func Test_App_Test_no_timeout_infinitely(t *testing.T) {
|
|||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
req := httptest.NewRequest(MethodGet, "/", http.NoBody)
|
||||
_, err = app.Test(req, -1)
|
||||
}()
|
||||
|
||||
|
@ -1696,7 +1701,7 @@ func Test_App_SetTLSHandler(t *testing.T) {
|
|||
|
||||
func Test_App_AddCustomRequestMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
methods := append(DefaultMethods, "TEST")
|
||||
methods := append(DefaultMethods, "TEST") //nolint:gocritic // We want a new slice here
|
||||
app := New(Config{
|
||||
RequestMethods: methods,
|
||||
})
|
||||
|
|
83
client.go
83
client.go
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -50,6 +51,7 @@ type Args = fasthttp.Args
|
|||
// Copy from fasthttp
|
||||
type RetryIfFunc = fasthttp.RetryIfFunc
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var defaultClient Client
|
||||
|
||||
// Client implements http client.
|
||||
|
@ -186,11 +188,11 @@ func (a *Agent) Parse() error {
|
|||
|
||||
uri := a.req.URI()
|
||||
|
||||
isTLS := false
|
||||
var isTLS bool
|
||||
scheme := uri.Scheme()
|
||||
if bytes.Equal(scheme, strHTTPS) {
|
||||
if bytes.Equal(scheme, []byte(schemeHTTPS)) {
|
||||
isTLS = true
|
||||
} else if !bytes.Equal(scheme, strHTTP) {
|
||||
} else if !bytes.Equal(scheme, []byte(schemeHTTP)) {
|
||||
return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
|
||||
}
|
||||
|
||||
|
@ -241,7 +243,7 @@ func (a *Agent) SetBytesV(k string, v []byte) *Agent {
|
|||
// SetBytesKV sets the given 'key: value' header.
|
||||
//
|
||||
// Use AddBytesKV for setting multiple header values under the same key.
|
||||
func (a *Agent) SetBytesKV(k []byte, v []byte) *Agent {
|
||||
func (a *Agent) SetBytesKV(k, v []byte) *Agent {
|
||||
a.req.Header.SetBytesKV(k, v)
|
||||
|
||||
return a
|
||||
|
@ -281,7 +283,7 @@ func (a *Agent) AddBytesV(k string, v []byte) *Agent {
|
|||
//
|
||||
// Multiple headers with the same key may be added with this function.
|
||||
// Use SetBytesKV for setting a single header for the given key.
|
||||
func (a *Agent) AddBytesKV(k []byte, v []byte) *Agent {
|
||||
func (a *Agent) AddBytesKV(k, v []byte) *Agent {
|
||||
a.req.Header.AddBytesKV(k, v)
|
||||
|
||||
return a
|
||||
|
@ -652,10 +654,8 @@ func (a *Agent) Reuse() *Agent {
|
|||
// certificate chain and host name.
|
||||
func (a *Agent) InsecureSkipVerify() *Agent {
|
||||
if a.HostClient.TLSConfig == nil {
|
||||
/* #nosec G402 */
|
||||
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402
|
||||
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We explicitly let the user set insecure mode here
|
||||
} else {
|
||||
/* #nosec G402 */
|
||||
a.HostClient.TLSConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
|
@ -728,14 +728,14 @@ func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent {
|
|||
// Bytes returns the status code, bytes body and errors of url.
|
||||
//
|
||||
// it's not safe to use Agent after calling [Agent.Bytes]
|
||||
func (a *Agent) Bytes() (code int, body []byte, errs []error) {
|
||||
func (a *Agent) Bytes() (int, []byte, []error) {
|
||||
defer a.release()
|
||||
return a.bytes()
|
||||
}
|
||||
|
||||
func (a *Agent) bytes() (code int, body []byte, errs []error) {
|
||||
func (a *Agent) bytes() (code int, body []byte, errs []error) { //nolint:nonamedreturns,revive // We want to overwrite the body in a deferred func. TODO: Check if we really need to do this. We eventually want to get rid of all named returns.
|
||||
if errs = append(errs, a.errs...); len(errs) > 0 {
|
||||
return
|
||||
return code, body, errs
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -760,7 +760,7 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
|
|||
code = resp.StatusCode()
|
||||
}
|
||||
|
||||
body = append(a.dest, resp.Body()...)
|
||||
body = append(a.dest, resp.Body()...) //nolint:gocritic // We want to append to the returned slice here
|
||||
|
||||
if nilResp {
|
||||
ReleaseResponse(resp)
|
||||
|
@ -770,25 +770,25 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
|
|||
if a.timeout > 0 {
|
||||
if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil {
|
||||
errs = append(errs, err)
|
||||
return
|
||||
return code, body, errs
|
||||
}
|
||||
} else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) {
|
||||
if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil {
|
||||
errs = append(errs, err)
|
||||
return
|
||||
return code, body, errs
|
||||
}
|
||||
} else if err := a.HostClient.Do(req, resp); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return
|
||||
return code, body, errs
|
||||
}
|
||||
|
||||
func printDebugInfo(req *Request, resp *Response, w io.Writer) {
|
||||
msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr())
|
||||
_, _ = w.Write(utils.UnsafeBytes(msg))
|
||||
_, _ = req.WriteTo(w)
|
||||
_, _ = resp.WriteTo(w)
|
||||
_, _ = w.Write(utils.UnsafeBytes(msg)) //nolint:errcheck // This will never fail
|
||||
_, _ = req.WriteTo(w) //nolint:errcheck // This will never fail
|
||||
_, _ = resp.WriteTo(w) //nolint:errcheck // This will never fail
|
||||
}
|
||||
|
||||
// String returns the status code, string body and errors of url.
|
||||
|
@ -797,6 +797,7 @@ func printDebugInfo(req *Request, resp *Response, w io.Writer) {
|
|||
func (a *Agent) String() (int, string, []error) {
|
||||
defer a.release()
|
||||
code, body, errs := a.bytes()
|
||||
// TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it?
|
||||
|
||||
return code, utils.UnsafeString(body), errs
|
||||
}
|
||||
|
@ -805,12 +806,15 @@ func (a *Agent) String() (int, string, []error) {
|
|||
// And bytes body will be unmarshalled to given v.
|
||||
//
|
||||
// it's not safe to use Agent after calling [Agent.Struct]
|
||||
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
|
||||
func (a *Agent) Struct(v interface{}) (int, []byte, []error) {
|
||||
defer a.release()
|
||||
if code, body, errs = a.bytes(); len(errs) > 0 {
|
||||
return
|
||||
|
||||
code, body, errs := a.bytes()
|
||||
if len(errs) > 0 {
|
||||
return code, body, errs
|
||||
}
|
||||
|
||||
// TODO: This should only be done once
|
||||
if a.jsonDecoder == nil {
|
||||
a.jsonDecoder = json.Unmarshal
|
||||
}
|
||||
|
@ -819,7 +823,7 @@ func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
|
|||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return
|
||||
return code, body, errs
|
||||
}
|
||||
|
||||
func (a *Agent) release() {
|
||||
|
@ -855,6 +859,7 @@ func (a *Agent) reset() {
|
|||
a.formFiles = a.formFiles[:0]
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use global vars here
|
||||
var (
|
||||
clientPool sync.Pool
|
||||
agentPool = sync.Pool{
|
||||
|
@ -877,7 +882,11 @@ func AcquireClient() *Client {
|
|||
if v == nil {
|
||||
return &Client{}
|
||||
}
|
||||
return v.(*Client)
|
||||
c, ok := v.(*Client)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *Client"))
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// ReleaseClient returns c acquired via AcquireClient to client pool.
|
||||
|
@ -899,7 +908,11 @@ func ReleaseClient(c *Client) {
|
|||
// no longer needed. This allows Agent recycling, reduces GC pressure
|
||||
// and usually improves performance.
|
||||
func AcquireAgent() *Agent {
|
||||
return agentPool.Get().(*Agent)
|
||||
a, ok := agentPool.Get().(*Agent)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *Agent"))
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// ReleaseAgent returns a acquired via AcquireAgent to Agent pool.
|
||||
|
@ -922,7 +935,11 @@ func AcquireResponse() *Response {
|
|||
if v == nil {
|
||||
return &Response{}
|
||||
}
|
||||
return v.(*Response)
|
||||
r, ok := v.(*Response)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *Response"))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ReleaseResponse return resp acquired via AcquireResponse to response pool.
|
||||
|
@ -945,7 +962,11 @@ func AcquireArgs() *Args {
|
|||
if v == nil {
|
||||
return &Args{}
|
||||
}
|
||||
return v.(*Args)
|
||||
a, ok := v.(*Args)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *Args"))
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
|
||||
|
@ -966,7 +987,11 @@ func AcquireFormFile() *FormFile {
|
|||
if v == nil {
|
||||
return &FormFile{}
|
||||
}
|
||||
return v.(*FormFile)
|
||||
ff, ok := v.(*FormFile)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *FormFile"))
|
||||
}
|
||||
return ff
|
||||
}
|
||||
|
||||
// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool.
|
||||
|
@ -981,9 +1006,7 @@ func ReleaseFormFile(ff *FormFile) {
|
|||
formFilePool.Put(ff)
|
||||
}
|
||||
|
||||
var (
|
||||
strHTTP = []byte("http")
|
||||
strHTTPS = []byte("https")
|
||||
const (
|
||||
defaultUserAgent = "fiber"
|
||||
)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint:wrapcheck // We must not wrap errors in tests
|
||||
package fiber
|
||||
|
||||
import (
|
||||
|
@ -19,6 +20,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2/internal/tlstest"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
)
|
||||
|
||||
|
@ -295,8 +297,10 @@ func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) {
|
|||
handler := func(c *Ctx) error {
|
||||
c.Request().Header.VisitAll(func(key, value []byte) {
|
||||
if k := string(key); k == "K1" || k == "K2" {
|
||||
_, _ = c.Write(key)
|
||||
_, _ = c.Write(value)
|
||||
_, err := c.Write(key)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
_, err = c.Write(value)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
|
@ -581,25 +585,26 @@ type readErrorConn struct {
|
|||
net.Conn
|
||||
}
|
||||
|
||||
func (r *readErrorConn) Read(p []byte) (int, error) {
|
||||
func (*readErrorConn) Read(_ []byte) (int, error) {
|
||||
return 0, fmt.Errorf("error")
|
||||
}
|
||||
|
||||
func (r *readErrorConn) Write(p []byte) (int, error) {
|
||||
func (*readErrorConn) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (r *readErrorConn) Close() error {
|
||||
func (*readErrorConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *readErrorConn) LocalAddr() net.Addr {
|
||||
func (*readErrorConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *readErrorConn) RemoteAddr() net.Addr {
|
||||
func (*readErrorConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_Client_Agent_RetryIf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -783,7 +788,10 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) {
|
|||
buf := make([]byte, fh1.Size)
|
||||
f, err := fh1.Open()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() { _ = f.Close() }()
|
||||
defer func() {
|
||||
err := f.Close()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
_, err = f.Read(buf)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "form file", string(buf))
|
||||
|
@ -831,13 +839,16 @@ func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) {
|
|||
basename := filepath.Base(filename)
|
||||
utils.AssertEqual(t, fh.Filename, basename)
|
||||
|
||||
b1, err := os.ReadFile(filename)
|
||||
b1, err := os.ReadFile(filename) //nolint:gosec // We're in a test so reading user-provided files by name is fine
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
b2 := make([]byte, fh.Size)
|
||||
f, err := fh.Open()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() { _ = f.Close() }()
|
||||
defer func() {
|
||||
err := f.Close()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
_, err = f.Read(b2)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, b1, b2)
|
||||
|
@ -962,6 +973,7 @@ func Test_Client_Agent_InsecureSkipVerify(t *testing.T) {
|
|||
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
//nolint:gosec // We're in a test so using old ciphers is fine
|
||||
serverTLSConf := &tls.Config{
|
||||
Certificates: []tls.Certificate{cer},
|
||||
}
|
||||
|
@ -1137,7 +1149,7 @@ func Test_Client_Agent_Struct(t *testing.T) {
|
|||
defer ReleaseAgent(a)
|
||||
defer a.ConnectionClose()
|
||||
request := a.Request()
|
||||
request.Header.SetMethod("GET")
|
||||
request.Header.SetMethod(MethodGet)
|
||||
request.SetRequestURI("http://example.com")
|
||||
err := a.Parse()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -1198,8 +1210,8 @@ type errorMultipartWriter struct {
|
|||
count int
|
||||
}
|
||||
|
||||
func (e *errorMultipartWriter) Boundary() string { return "myBoundary" }
|
||||
func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil }
|
||||
func (*errorMultipartWriter) Boundary() string { return "myBoundary" }
|
||||
func (*errorMultipartWriter) SetBoundary(_ string) error { return nil }
|
||||
func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
|
||||
if e.count == 0 {
|
||||
e.count++
|
||||
|
@ -1207,8 +1219,8 @@ func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
|
|||
}
|
||||
return errorWriter{}, nil
|
||||
}
|
||||
func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
|
||||
func (e *errorMultipartWriter) Close() error { return errors.New("Close error") }
|
||||
func (*errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
|
||||
func (*errorMultipartWriter) Close() error { return errors.New("Close error") }
|
||||
|
||||
type errorWriter struct{}
|
||||
|
||||
|
|
2
color.go
2
color.go
|
@ -53,6 +53,8 @@ type Colors struct {
|
|||
}
|
||||
|
||||
// DefaultColors Default color codes
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var DefaultColors = Colors{
|
||||
Black: "\u001b[90m",
|
||||
Red: "\u001b[91m",
|
||||
|
|
188
ctx.go
188
ctx.go
|
@ -27,10 +27,16 @@ import (
|
|||
"github.com/gofiber/fiber/v2/internal/dictpool"
|
||||
"github.com/gofiber/fiber/v2/internal/schema"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
schemeHTTP = "http"
|
||||
schemeHTTPS = "https"
|
||||
)
|
||||
|
||||
// maxParams defines the maximum number of parameters per route.
|
||||
const maxParams = 30
|
||||
|
||||
|
@ -45,6 +51,7 @@ const (
|
|||
// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
|
||||
const userContextKey = "__local_user_context__"
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use global vars here
|
||||
var (
|
||||
// decoderPoolMap helps to improve BodyParser's, QueryParser's and ReqHeaderParser's performance
|
||||
decoderPoolMap = map[string]*sync.Pool{}
|
||||
|
@ -52,6 +59,7 @@ var (
|
|||
tags = []string{queryTag, bodyTag, reqHeaderTag, paramsTag}
|
||||
)
|
||||
|
||||
//nolint:gochecknoinits // init() is used to initialize a global map variable
|
||||
func init() {
|
||||
for _, tag := range tags {
|
||||
decoderPoolMap[tag] = &sync.Pool{New: func() interface{} {
|
||||
|
@ -100,9 +108,10 @@ type TLSHandler struct {
|
|||
}
|
||||
|
||||
// GetClientInfo Callback function to set CHI
|
||||
// TODO: Why is this a getter which sets stuff?
|
||||
func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
t.clientHelloInfo = info
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // Not returning anything useful here is probably fine
|
||||
}
|
||||
|
||||
// Range data for c.Range
|
||||
|
@ -151,7 +160,10 @@ type ParserConfig struct {
|
|||
|
||||
// AcquireCtx retrieves a new Ctx from the pool.
|
||||
func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
|
||||
c := app.pool.Get().(*Ctx)
|
||||
c, ok := app.pool.Get().(*Ctx)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *Ctx"))
|
||||
}
|
||||
// Set app reference
|
||||
c.app = app
|
||||
// Reset route and handler index
|
||||
|
@ -388,7 +400,6 @@ func (c *Ctx) BodyParser(out interface{}) error {
|
|||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return c.parseToStruct(bodyTag, out, data)
|
||||
|
@ -401,7 +412,10 @@ func (c *Ctx) BodyParser(out interface{}) error {
|
|||
return c.parseToStruct(bodyTag, out, data.Value)
|
||||
}
|
||||
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
|
||||
return xml.Unmarshal(c.Body(), out)
|
||||
if err := xml.Unmarshal(c.Body(), out); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// No suitable content type found
|
||||
return ErrUnprocessableEntity
|
||||
|
@ -673,8 +687,11 @@ func (c *Ctx) Hostname() string {
|
|||
|
||||
// Port returns the remote port of the request.
|
||||
func (c *Ctx) Port() string {
|
||||
port := c.fasthttp.RemoteAddr().(*net.TCPAddr).Port
|
||||
return strconv.Itoa(port)
|
||||
tcpaddr, ok := c.fasthttp.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *net.TCPAddr"))
|
||||
}
|
||||
return strconv.Itoa(tcpaddr.Port)
|
||||
}
|
||||
|
||||
// IP returns the remote IP address of the request.
|
||||
|
@ -691,13 +708,16 @@ func (c *Ctx) IP() string {
|
|||
// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
|
||||
// When IP validation is enabled, any invalid IPs will be omitted.
|
||||
func (c *Ctx) extractIPsFromHeader(header string) []string {
|
||||
// TODO: Reuse the c.extractIPFromHeader func somehow in here
|
||||
|
||||
headerValue := c.Get(header)
|
||||
|
||||
// We can't know how many IPs we will return, but we will try to guess with this constant division.
|
||||
// Counting ',' makes function slower for about 50ns in general case.
|
||||
estimatedCount := len(headerValue) / 8
|
||||
if estimatedCount > 8 {
|
||||
estimatedCount = 8 // Avoid big allocation on big header
|
||||
const maxEstimatedCount = 8
|
||||
estimatedCount := len(headerValue) / maxEstimatedCount
|
||||
if estimatedCount > maxEstimatedCount {
|
||||
estimatedCount = maxEstimatedCount // Avoid big allocation on big header
|
||||
}
|
||||
|
||||
ipsFound := make([]string, 0, estimatedCount)
|
||||
|
@ -707,11 +727,10 @@ func (c *Ctx) extractIPsFromHeader(header string) []string {
|
|||
|
||||
iploop:
|
||||
for {
|
||||
v4 := false
|
||||
v6 := false
|
||||
var v4, v6 bool
|
||||
|
||||
// Manually splitting string without allocating slice, working with parts directly
|
||||
i, j = j+1, j+2
|
||||
i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
|
||||
|
||||
if j > len(headerValue) {
|
||||
break
|
||||
|
@ -758,9 +777,10 @@ func (c *Ctx) extractIPFromHeader(header string) string {
|
|||
|
||||
iploop:
|
||||
for {
|
||||
v4 := false
|
||||
v6 := false
|
||||
i, j = j+1, j+2
|
||||
var v4, v6 bool
|
||||
|
||||
// Manually splitting string without allocating slice, working with parts directly
|
||||
i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
|
||||
|
||||
if j > len(headerValue) {
|
||||
break
|
||||
|
@ -793,14 +813,14 @@ func (c *Ctx) extractIPFromHeader(header string) string {
|
|||
return c.fasthttp.RemoteIP().String()
|
||||
}
|
||||
|
||||
// default behaviour if IP validation is not enabled is just to return whatever value is
|
||||
// default behavior if IP validation is not enabled is just to return whatever value is
|
||||
// in the proxy header. Even if it is empty or invalid
|
||||
return c.Get(c.app.config.ProxyHeader)
|
||||
}
|
||||
|
||||
// IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header.
|
||||
// When IP validation is enabled, only valid IPs are returned.
|
||||
func (c *Ctx) IPs() (ips []string) {
|
||||
func (c *Ctx) IPs() []string {
|
||||
return c.extractIPsFromHeader(HeaderXForwardedFor)
|
||||
}
|
||||
|
||||
|
@ -839,7 +859,7 @@ func (c *Ctx) JSON(data interface{}) error {
|
|||
func (c *Ctx) JSONP(data interface{}, callback ...string) error {
|
||||
raw, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to marshal: %w", err)
|
||||
}
|
||||
|
||||
var result, cb string
|
||||
|
@ -877,11 +897,11 @@ func (c *Ctx) Links(link ...string) {
|
|||
bb := bytebufferpool.Get()
|
||||
for i := range link {
|
||||
if i%2 == 0 {
|
||||
_ = bb.WriteByte('<')
|
||||
_, _ = bb.WriteString(link[i])
|
||||
_ = bb.WriteByte('>')
|
||||
_ = bb.WriteByte('<') //nolint:errcheck // This will never fail
|
||||
_, _ = bb.WriteString(link[i]) //nolint:errcheck // This will never fail
|
||||
_ = bb.WriteByte('>') //nolint:errcheck // This will never fail
|
||||
} else {
|
||||
_, _ = bb.WriteString(`; rel="` + link[i] + `",`)
|
||||
_, _ = bb.WriteString(`; rel="` + link[i] + `",`) //nolint:errcheck // This will never fail
|
||||
}
|
||||
}
|
||||
c.setCanonical(HeaderLink, utils.TrimRight(c.app.getString(bb.Bytes()), ','))
|
||||
|
@ -890,7 +910,7 @@ func (c *Ctx) Links(link ...string) {
|
|||
|
||||
// Locals makes it possible to pass interface{} values under keys scoped to the request
|
||||
// and therefore available to all following routes that match the request.
|
||||
func (c *Ctx) Locals(key interface{}, value ...interface{}) (val interface{}) {
|
||||
func (c *Ctx) Locals(key interface{}, value ...interface{}) interface{} {
|
||||
if len(value) == 0 {
|
||||
return c.fasthttp.UserValue(key)
|
||||
}
|
||||
|
@ -933,9 +953,10 @@ func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo {
|
|||
}
|
||||
|
||||
// Next executes the next method in the stack that matches the current route.
|
||||
func (c *Ctx) Next() (err error) {
|
||||
func (c *Ctx) Next() error {
|
||||
// Increment handler index
|
||||
c.indexHandler++
|
||||
var err error
|
||||
// Did we executed all route handlers?
|
||||
if c.indexHandler < len(c.route.Handlers) {
|
||||
// Continue route stack
|
||||
|
@ -947,7 +968,7 @@ func (c *Ctx) Next() (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
// RestartRouting instead of going to the next handler. This may be usefull after
|
||||
// RestartRouting instead of going to the next handler. This may be useful after
|
||||
// changing the request path. Note that handlers might be executed again.
|
||||
func (c *Ctx) RestartRouting() error {
|
||||
c.indexRoute = -1
|
||||
|
@ -1017,9 +1038,8 @@ func (c *Ctx) ParamsInt(key string, defaultValue ...int) (int, error) {
|
|||
if err != nil {
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0], nil
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
return 0, fmt.Errorf("failed to convert: %w", err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
|
@ -1044,15 +1064,16 @@ func (c *Ctx) Path(override ...string) string {
|
|||
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
|
||||
func (c *Ctx) Protocol() string {
|
||||
if c.fasthttp.IsTLS() {
|
||||
return "https"
|
||||
return schemeHTTPS
|
||||
}
|
||||
if !c.IsProxyTrusted() {
|
||||
return "http"
|
||||
return schemeHTTP
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
scheme := schemeHTTP
|
||||
const lenXHeaderName = 12
|
||||
c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
|
||||
if len(key) < 12 {
|
||||
if len(key) < lenXHeaderName {
|
||||
return // Neither "X-Forwarded-" nor "X-Url-Scheme"
|
||||
}
|
||||
switch {
|
||||
|
@ -1067,7 +1088,7 @@ func (c *Ctx) Protocol() string {
|
|||
scheme = v
|
||||
}
|
||||
} else if bytes.Equal(key, []byte(HeaderXForwardedSsl)) && bytes.Equal(val, []byte("on")) {
|
||||
scheme = "https"
|
||||
scheme = schemeHTTPS
|
||||
}
|
||||
|
||||
case bytes.Equal(key, []byte(HeaderXUrlScheme)):
|
||||
|
@ -1100,9 +1121,8 @@ func (c *Ctx) QueryInt(key string, defaultValue ...int) int {
|
|||
if err != nil {
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
return value
|
||||
|
@ -1133,7 +1153,6 @@ func (c *Ctx) QueryParser(out interface{}) error {
|
|||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -1150,10 +1169,9 @@ func parseParamSquareBrackets(k string) (string, error) {
|
|||
kbytes := []byte(k)
|
||||
|
||||
for i, b := range kbytes {
|
||||
|
||||
if b == '[' && kbytes[i+1] != ']' {
|
||||
if err := bb.WriteByte('.'); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to write: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1162,7 +1180,7 @@ func parseParamSquareBrackets(k string) (string, error) {
|
|||
}
|
||||
|
||||
if err := bb.WriteByte(b); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to write: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1184,21 +1202,27 @@ func (c *Ctx) ReqHeaderParser(out interface{}) error {
|
|||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return c.parseToStruct(reqHeaderTag, out, data)
|
||||
}
|
||||
|
||||
func (c *Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error {
|
||||
func (*Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error {
|
||||
// Get decoder from pool
|
||||
schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder)
|
||||
schemaDecoder, ok := decoderPoolMap[aliasTag].Get().(*schema.Decoder)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *schema.Decoder"))
|
||||
}
|
||||
defer decoderPoolMap[aliasTag].Put(schemaDecoder)
|
||||
|
||||
// Set alias tag
|
||||
schemaDecoder.SetAliasTag(aliasTag)
|
||||
|
||||
return schemaDecoder.Decode(out, data)
|
||||
if err := schemaDecoder.Decode(out, data); err != nil {
|
||||
return fmt.Errorf("failed to decode: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func equalFieldType(out interface{}, kind reflect.Kind, key string) bool {
|
||||
|
@ -1248,24 +1272,23 @@ var (
|
|||
)
|
||||
|
||||
// Range returns a struct containing the type and a slice of ranges.
|
||||
func (c *Ctx) Range(size int) (rangeData Range, err error) {
|
||||
func (c *Ctx) Range(size int) (Range, error) {
|
||||
var rangeData Range
|
||||
rangeStr := c.Get(HeaderRange)
|
||||
if rangeStr == "" || !strings.Contains(rangeStr, "=") {
|
||||
err = ErrRangeMalformed
|
||||
return
|
||||
return rangeData, ErrRangeMalformed
|
||||
}
|
||||
data := strings.Split(rangeStr, "=")
|
||||
if len(data) != 2 {
|
||||
err = ErrRangeMalformed
|
||||
return
|
||||
const expectedDataParts = 2
|
||||
if len(data) != expectedDataParts {
|
||||
return rangeData, ErrRangeMalformed
|
||||
}
|
||||
rangeData.Type = data[0]
|
||||
arr := strings.Split(data[1], ",")
|
||||
for i := 0; i < len(arr); i++ {
|
||||
item := strings.Split(arr[i], "-")
|
||||
if len(item) == 1 {
|
||||
err = ErrRangeMalformed
|
||||
return
|
||||
return rangeData, ErrRangeMalformed
|
||||
}
|
||||
start, startErr := strconv.Atoi(item[0])
|
||||
end, endErr := strconv.Atoi(item[1])
|
||||
|
@ -1290,11 +1313,10 @@ func (c *Ctx) Range(size int) (rangeData Range, err error) {
|
|||
})
|
||||
}
|
||||
if len(rangeData.Ranges) < 1 {
|
||||
err = ErrRangeUnsatisfiable
|
||||
return
|
||||
return rangeData, ErrRangeUnsatisfiable
|
||||
}
|
||||
|
||||
return
|
||||
return rangeData, nil
|
||||
}
|
||||
|
||||
// Redirect to the URL derived from the specified path, with specified status.
|
||||
|
@ -1330,7 +1352,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
|
|||
if !segment.IsParam {
|
||||
_, err := buf.WriteString(segment.Const)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to write string: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
@ -1341,7 +1363,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
|
|||
if isSame || isGreedy {
|
||||
_, err := buf.WriteString(utils.ToString(val))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to write string: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1373,10 +1395,10 @@ func (c *Ctx) RedirectToRoute(routeName string, params Map, status ...int) error
|
|||
|
||||
i := 1
|
||||
for k, v := range queries {
|
||||
_, _ = queryText.WriteString(k + "=" + v)
|
||||
_, _ = queryText.WriteString(k + "=" + v) //nolint:errcheck // This will never fail
|
||||
|
||||
if i != len(queries) {
|
||||
_, _ = queryText.WriteString("&")
|
||||
_, _ = queryText.WriteString("&") //nolint:errcheck // This will never fail
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
@ -1399,7 +1421,6 @@ func (c *Ctx) RedirectBack(fallback string, status ...int) error {
|
|||
// Render a template with data and sends a text/html response.
|
||||
// We support the following engines: html, amber, handlebars, mustache, pug
|
||||
func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
||||
var err error
|
||||
// Get new buffer from pool
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
@ -1421,7 +1442,7 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
|||
// Render template from Views
|
||||
if app.config.Views != nil {
|
||||
if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to render: %w", err)
|
||||
}
|
||||
|
||||
rendered = true
|
||||
|
@ -1433,17 +1454,18 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
|||
if !rendered {
|
||||
// Render raw template using 'name' as filepath if no engine is set
|
||||
var tmpl *template.Template
|
||||
if _, err = readContent(buf, name); err != nil {
|
||||
if _, err := readContent(buf, name); err != nil {
|
||||
return err
|
||||
}
|
||||
// Parse template
|
||||
if tmpl, err = template.New("").Parse(c.app.getString(buf.Bytes())); err != nil {
|
||||
return err
|
||||
tmpl, err := template.New("").Parse(c.app.getString(buf.Bytes()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse: %w", err)
|
||||
}
|
||||
buf.Reset()
|
||||
// Render template
|
||||
if err = tmpl.Execute(buf, bind); err != nil {
|
||||
return err
|
||||
if err := tmpl.Execute(buf, bind); err != nil {
|
||||
return fmt.Errorf("failed to execute: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1451,8 +1473,8 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
|||
c.fasthttp.Response.Header.SetContentType(MIMETextHTMLCharsetUTF8)
|
||||
// Set rendered template to body
|
||||
c.fasthttp.Response.SetBody(buf.Bytes())
|
||||
// Return err if exist
|
||||
return err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Ctx) renderExtensions(bind interface{}) {
|
||||
|
@ -1501,28 +1523,32 @@ func (c *Ctx) Route() *Route {
|
|||
}
|
||||
|
||||
// SaveFile saves any multipart file to disk.
|
||||
func (c *Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error {
|
||||
func (*Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error {
|
||||
return fasthttp.SaveMultipartFile(fileheader, path)
|
||||
}
|
||||
|
||||
// SaveFileToStorage saves any multipart file to an external storage system.
|
||||
func (c *Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
|
||||
func (*Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
|
||||
file, err := fileheader.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read: %w", err)
|
||||
}
|
||||
|
||||
return storage.Set(path, content, 0)
|
||||
if err := storage.Set(path, content, 0); err != nil {
|
||||
return fmt.Errorf("failed to store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Secure returns whether a secure connection was established.
|
||||
func (c *Ctx) Secure() bool {
|
||||
return c.Protocol() == "https"
|
||||
return c.Protocol() == schemeHTTPS
|
||||
}
|
||||
|
||||
// Send sets the HTTP response body without copying it.
|
||||
|
@ -1533,6 +1559,7 @@ func (c *Ctx) Send(body []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use global vars here
|
||||
var (
|
||||
sendFileOnce sync.Once
|
||||
sendFileFS *fasthttp.FS
|
||||
|
@ -1548,6 +1575,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
|
|||
|
||||
// https://github.com/valyala/fasthttp/blob/c7576cc10cabfc9c993317a2d3f8355497bea156/fs.go#L129-L134
|
||||
sendFileOnce.Do(func() {
|
||||
const cacheDuration = 10 * time.Second
|
||||
sendFileFS = &fasthttp.FS{
|
||||
Root: "",
|
||||
AllowEmptyRoot: true,
|
||||
|
@ -1555,7 +1583,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
|
|||
AcceptByteRange: true,
|
||||
Compress: true,
|
||||
CompressedFileSuffix: c.app.config.CompressedFileSuffix,
|
||||
CacheDuration: 10 * time.Second,
|
||||
CacheDuration: cacheDuration,
|
||||
IndexNames: []string{"index.html"},
|
||||
PathNotFound: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.Response.SetStatusCode(StatusNotFound)
|
||||
|
@ -1579,7 +1607,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
|
|||
var err error
|
||||
file = filepath.FromSlash(file)
|
||||
if file, err = filepath.Abs(file); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to determine abs file path: %w", err)
|
||||
}
|
||||
if hasTrailingSlash {
|
||||
file += "/"
|
||||
|
@ -1644,11 +1672,11 @@ func (c *Ctx) SendStream(stream io.Reader, size ...int) error {
|
|||
}
|
||||
|
||||
// Set sets the response's HTTP header field to the specified key, value.
|
||||
func (c *Ctx) Set(key string, val string) {
|
||||
func (c *Ctx) Set(key, val string) {
|
||||
c.fasthttp.Response.Header.Set(key, val)
|
||||
}
|
||||
|
||||
func (c *Ctx) setCanonical(key string, val string) {
|
||||
func (c *Ctx) setCanonical(key, val string) {
|
||||
c.fasthttp.Response.Header.SetCanonical(c.app.getBytes(key), c.app.getBytes(val))
|
||||
}
|
||||
|
||||
|
@ -1719,6 +1747,7 @@ func (c *Ctx) Write(p []byte) (int, error) {
|
|||
|
||||
// Writef appends f & a into response body writer.
|
||||
func (c *Ctx) Writef(f string, a ...interface{}) (int, error) {
|
||||
//nolint:wrapcheck // This must not be wrapped
|
||||
return fmt.Fprintf(c.fasthttp.Response.BodyWriter(), f, a...)
|
||||
}
|
||||
|
||||
|
@ -1760,8 +1789,9 @@ func (c *Ctx) configDependentPaths() {
|
|||
// Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed,
|
||||
// since the first three characters area select a list of routes
|
||||
c.treePath = c.treePath[0:0]
|
||||
if len(c.detectionPath) >= 3 {
|
||||
c.treePath = c.detectionPath[:3]
|
||||
const maxDetectionPaths = 3
|
||||
if len(c.detectionPath) >= maxDetectionPaths {
|
||||
c.treePath = c.detectionPath[:maxDetectionPaths]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1786,7 +1816,7 @@ func (c *Ctx) IsProxyTrusted() bool {
|
|||
}
|
||||
|
||||
// IsLocalHost will return true if address is a localhost address.
|
||||
func (c *Ctx) isLocalHost(address string) bool {
|
||||
func (*Ctx) isLocalHost(address string) bool {
|
||||
localHosts := []string{"127.0.0.1", "0.0.0.0", "::1"}
|
||||
for _, h := range localHosts {
|
||||
if strings.Contains(address, h) {
|
||||
|
|
314
ctx_test.go
314
ctx_test.go
|
@ -2,6 +2,7 @@
|
|||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||
// 📌 API Documentation: https://docs.gofiber.io
|
||||
|
||||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package fiber
|
||||
|
||||
import (
|
||||
|
@ -28,6 +29,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
@ -361,8 +363,10 @@ func Test_Ctx_BodyParser(t *testing.T) {
|
|||
{
|
||||
var gzipJSON bytes.Buffer
|
||||
w := gzip.NewWriter(&gzipJSON)
|
||||
_, _ = w.Write([]byte(`{"name":"john"}`))
|
||||
_ = w.Close()
|
||||
_, err := w.Write([]byte(`{"name":"john"}`))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = w.Close()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
c.Request().Header.SetContentType(MIMEApplicationJSON)
|
||||
c.Request().Header.Set(HeaderContentEncoding, "gzip")
|
||||
|
@ -431,9 +435,7 @@ func Test_Ctx_ParamParser(t *testing.T) {
|
|||
UserID uint `params:"userId"`
|
||||
RoleID uint `params:"roleId"`
|
||||
}
|
||||
var (
|
||||
d = new(Demo)
|
||||
)
|
||||
d := new(Demo)
|
||||
if err := ctx.ParamsParser(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -519,7 +521,7 @@ func Benchmark_Ctx_BodyParser_JSON(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = c.BodyParser(d)
|
||||
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||
}
|
||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||
utils.AssertEqual(b, "john", d.Name)
|
||||
|
@ -543,7 +545,7 @@ func Benchmark_Ctx_BodyParser_XML(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = c.BodyParser(d)
|
||||
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||
}
|
||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||
utils.AssertEqual(b, "john", d.Name)
|
||||
|
@ -567,7 +569,7 @@ func Benchmark_Ctx_BodyParser_Form(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = c.BodyParser(d)
|
||||
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||
}
|
||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||
utils.AssertEqual(b, "john", d.Name)
|
||||
|
@ -592,7 +594,7 @@ func Benchmark_Ctx_BodyParser_MultipartForm(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = c.BodyParser(d)
|
||||
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||
}
|
||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||
utils.AssertEqual(b, "john", d.Name)
|
||||
|
@ -879,12 +881,13 @@ func Test_Ctx_FormFile(t *testing.T) {
|
|||
|
||||
f, err := fh.Open()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
utils.AssertEqual(t, nil, f.Close())
|
||||
}()
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
_, err = io.Copy(b, f)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
f.Close()
|
||||
utils.AssertEqual(t, "hello world", b.String())
|
||||
return nil
|
||||
})
|
||||
|
@ -897,8 +900,7 @@ func Test_Ctx_FormFile(t *testing.T) {
|
|||
|
||||
_, err = ioWriter.Write([]byte("hello world"))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
writer.Close()
|
||||
utils.AssertEqual(t, nil, writer.Close())
|
||||
|
||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||
req.Header.Set(HeaderContentType, writer.FormDataContentType())
|
||||
|
@ -921,10 +923,9 @@ func Test_Ctx_FormValue(t *testing.T) {
|
|||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
|
||||
utils.AssertEqual(t, nil, writer.Close())
|
||||
|
||||
writer.Close()
|
||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
|
||||
req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))
|
||||
|
@ -1240,7 +1241,7 @@ func Test_Ctx_IP(t *testing.T) {
|
|||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
// default behaviour will return the remote IP from the stack
|
||||
// default behavior will return the remote IP from the stack
|
||||
utils.AssertEqual(t, "0.0.0.0", c.IP())
|
||||
|
||||
// X-Forwarded-For is set, but it is ignored because proxyHeader is not set
|
||||
|
@ -1252,7 +1253,7 @@ func Test_Ctx_IP(t *testing.T) {
|
|||
func Test_Ctx_IP_ProxyHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// make sure that the same behaviour exists for different proxy header names
|
||||
// make sure that the same behavior exists for different proxy header names
|
||||
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
|
||||
|
||||
for _, proxyHeaderName := range proxyHeaderNames {
|
||||
|
@ -1286,7 +1287,7 @@ func Test_Ctx_IP_ProxyHeader(t *testing.T) {
|
|||
func Test_Ctx_IP_ProxyHeader_With_IP_Validation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// make sure that the same behaviour exists for different proxy header names
|
||||
// make sure that the same behavior exists for different proxy header names
|
||||
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
|
||||
|
||||
for _, proxyHeaderName := range proxyHeaderNames {
|
||||
|
@ -1625,35 +1626,43 @@ func Test_Ctx_ClientHelloInfo(t *testing.T) {
|
|||
})
|
||||
|
||||
// Test without TLS handler
|
||||
resp, _ := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body)
|
||||
|
||||
// Test with TLS Handler
|
||||
const (
|
||||
PSSWithSHA256 = 0x0804
|
||||
VersionTLS13 = 0x0304
|
||||
pssWithSHA256 = 0x0804
|
||||
versionTLS13 = 0x0304
|
||||
)
|
||||
app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{
|
||||
ServerName: "example.golang",
|
||||
SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256},
|
||||
SupportedVersions: []uint16{VersionTLS13},
|
||||
SignatureSchemes: []tls.SignatureScheme{pssWithSHA256},
|
||||
SupportedVersions: []uint16{versionTLS13},
|
||||
}}
|
||||
|
||||
// Test ServerName
|
||||
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, []byte("example.golang"), body)
|
||||
|
||||
// Test SignatureSchemes
|
||||
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, "["+strconv.Itoa(PSSWithSHA256)+"]", string(body))
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "["+strconv.Itoa(pssWithSHA256)+"]", string(body))
|
||||
|
||||
// Test SupportedVersions
|
||||
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, "["+strconv.Itoa(VersionTLS13)+"]", string(body))
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "["+strconv.Itoa(versionTLS13)+"]", string(body))
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_InvalidMethod
|
||||
|
@ -1688,10 +1697,9 @@ func Test_Ctx_MultipartForm(t *testing.T) {
|
|||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
|
||||
utils.AssertEqual(t, nil, writer.Close())
|
||||
|
||||
writer.Close()
|
||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||
req.Header.Set(HeaderContentType, fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
|
||||
req.Header.Set(HeaderContentLength, strconv.Itoa(len(body.Bytes())))
|
||||
|
@ -1706,8 +1714,8 @@ func Benchmark_Ctx_MultipartForm(b *testing.B) {
|
|||
app := New()
|
||||
|
||||
app.Post("/", func(c *Ctx) error {
|
||||
_, _ = c.MultipartForm()
|
||||
return nil
|
||||
_, err := c.MultipartForm()
|
||||
return err
|
||||
})
|
||||
|
||||
c := &fasthttp.RequestCtx{}
|
||||
|
@ -1889,11 +1897,16 @@ func Benchmark_Ctx_AllParams(b *testing.B) {
|
|||
for n := 0; n < b.N; n++ {
|
||||
res = c.AllParams()
|
||||
}
|
||||
utils.AssertEqual(b, map[string]string{"param1": "john",
|
||||
"param2": "doe",
|
||||
"param3": "is",
|
||||
"param4": "awesome"},
|
||||
res)
|
||||
utils.AssertEqual(
|
||||
b,
|
||||
map[string]string{
|
||||
"param1": "john",
|
||||
"param2": "doe",
|
||||
"param3": "is",
|
||||
"param4": "awesome",
|
||||
},
|
||||
res,
|
||||
)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Ctx_ParamsParse -benchmem -count=4
|
||||
|
@ -1964,31 +1977,31 @@ func Test_Ctx_Protocol(t *testing.T) {
|
|||
|
||||
c := app.AcquireCtx(freq)
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProto, "https, http")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https, http")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Ctx_Protocol -benchmem -count=4
|
||||
|
@ -2002,7 +2015,7 @@ func Benchmark_Ctx_Protocol(b *testing.B) {
|
|||
for n := 0; n < b.N; n++ {
|
||||
res = c.Protocol()
|
||||
}
|
||||
utils.AssertEqual(b, "http", res)
|
||||
utils.AssertEqual(b, schemeHTTP, res)
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_Protocol_TrustedProxy
|
||||
|
@ -2012,23 +2025,23 @@ func Test_Ctx_Protocol_TrustedProxy(t *testing.T) {
|
|||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_Protocol_TrustedProxyRange
|
||||
|
@ -2038,23 +2051,23 @@ func Test_Ctx_Protocol_TrustedProxyRange(t *testing.T) {
|
|||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
||||
utils.AssertEqual(t, "https", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_Protocol_UntrustedProxyRange
|
||||
|
@ -2064,23 +2077,23 @@ func Test_Ctx_Protocol_UntrustedProxyRange(t *testing.T) {
|
|||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_Protocol_UnTrustedProxy
|
||||
|
@ -2090,23 +2103,23 @@ func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) {
|
|||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
c.Request().Header.Reset()
|
||||
|
||||
utils.AssertEqual(t, "http", c.Protocol())
|
||||
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_Query
|
||||
|
@ -2224,7 +2237,12 @@ func Test_Ctx_SaveFile(t *testing.T) {
|
|||
tempFile, err := os.CreateTemp(os.TempDir(), "test-")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
defer os.Remove(tempFile.Name())
|
||||
defer func(file *os.File) {
|
||||
err := file.Close()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Remove(file.Name())
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}(tempFile)
|
||||
err = c.SaveFile(fh, tempFile.Name())
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -2242,7 +2260,7 @@ func Test_Ctx_SaveFile(t *testing.T) {
|
|||
|
||||
_, err = ioWriter.Write([]byte("hello world"))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
writer.Close()
|
||||
utils.AssertEqual(t, nil, writer.Close())
|
||||
|
||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
@ -2284,7 +2302,7 @@ func Test_Ctx_SaveFileToStorage(t *testing.T) {
|
|||
|
||||
_, err = ioWriter.Write([]byte("hello world"))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
writer.Close()
|
||||
utils.AssertEqual(t, nil, writer.Close())
|
||||
|
||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
@ -2370,7 +2388,9 @@ func Test_Ctx_Download(t *testing.T) {
|
|||
|
||||
f, err := os.Open("./ctx.go")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer f.Close()
|
||||
defer func() {
|
||||
utils.AssertEqual(t, nil, f.Close())
|
||||
}()
|
||||
|
||||
expect, err := io.ReadAll(f)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -2389,7 +2409,9 @@ func Test_Ctx_SendFile(t *testing.T) {
|
|||
// fetch file content
|
||||
f, err := os.Open("./ctx.go")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer f.Close()
|
||||
defer func() {
|
||||
utils.AssertEqual(t, nil, f.Close())
|
||||
}()
|
||||
expectFileContent, err := io.ReadAll(f)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// fetch file info for the not modified test case
|
||||
|
@ -2435,7 +2457,7 @@ func Test_Ctx_SendFile_404(t *testing.T) {
|
|||
return err
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -2473,11 +2495,11 @@ func Test_Ctx_SendFile_Immutable(t *testing.T) {
|
|||
for _, endpoint := range endpointsForTest {
|
||||
t.Run(endpoint, func(t *testing.T) {
|
||||
// 1st try
|
||||
resp, err := app.Test(httptest.NewRequest("GET", endpoint, nil))
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, StatusOK, resp.StatusCode)
|
||||
// 2nd try
|
||||
resp, err = app.Test(httptest.NewRequest("GET", endpoint, nil))
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, StatusOK, resp.StatusCode)
|
||||
})
|
||||
|
@ -2495,9 +2517,9 @@ func Test_Ctx_SendFile_RestoreOriginalURL(t *testing.T) {
|
|||
return err
|
||||
})
|
||||
|
||||
_, err1 := app.Test(httptest.NewRequest("GET", "/?test=true", nil))
|
||||
_, err1 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
|
||||
// second request required to confirm with zero allocation
|
||||
_, err2 := app.Test(httptest.NewRequest("GET", "/?test=true", nil))
|
||||
_, err2 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
|
||||
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
utils.AssertEqual(t, nil, err2)
|
||||
|
@ -2893,12 +2915,12 @@ func Test_Ctx_Render(t *testing.T) {
|
|||
err := c.Render("./.github/testdata/index.tmpl", Map{
|
||||
"Title": "Hello, World!",
|
||||
})
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString("overwrite")
|
||||
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||
|
||||
err = c.Render("./.github/testdata/template-non-exists.html", nil)
|
||||
|
@ -2918,12 +2940,12 @@ func Test_Ctx_RenderWithoutLocals(t *testing.T) {
|
|||
c.Locals("Title", "Hello, World!")
|
||||
defer app.ReleaseCtx(c)
|
||||
err := c.Render("./.github/testdata/index.tmpl", Map{})
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString("overwrite")
|
||||
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "<h1><no value></h1>", string(c.Response().Body()))
|
||||
}
|
||||
|
||||
|
@ -2937,14 +2959,13 @@ func Test_Ctx_RenderWithLocals(t *testing.T) {
|
|||
c.Locals("Title", "Hello, World!")
|
||||
defer app.ReleaseCtx(c)
|
||||
err := c.Render("./.github/testdata/index.tmpl", Map{})
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString("overwrite")
|
||||
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||
|
||||
}
|
||||
|
||||
func Test_Ctx_RenderWithBind(t *testing.T) {
|
||||
|
@ -2959,14 +2980,13 @@ func Test_Ctx_RenderWithBind(t *testing.T) {
|
|||
utils.AssertEqual(t, nil, err)
|
||||
defer app.ReleaseCtx(c)
|
||||
err = c.Render("./.github/testdata/index.tmpl", Map{})
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString("overwrite")
|
||||
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||
|
||||
}
|
||||
|
||||
func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
|
||||
|
@ -2982,12 +3002,12 @@ func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
|
|||
err = c.Render("./.github/testdata/index.tmpl", Map{
|
||||
"Title": "Hello from Fiber!",
|
||||
})
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
_, _ = buf.WriteString("overwrite")
|
||||
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "<h1>Hello from Fiber!</h1>", string(c.Response().Body()))
|
||||
}
|
||||
|
||||
|
@ -3005,13 +3025,12 @@ func Test_Ctx_RenderWithBindLocals(t *testing.T) {
|
|||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
c.Locals("Summary", "Test")
|
||||
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
err = c.Render("./.github/testdata/template.tmpl", Map{})
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
|
||||
|
||||
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
|
||||
}
|
||||
|
||||
func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
|
||||
|
@ -3027,11 +3046,12 @@ func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
|
|||
|
||||
c.Locals("Title", "This is a test.")
|
||||
defer app.ReleaseCtx(c)
|
||||
|
||||
err = c.Render("index.tmpl", Map{
|
||||
"Title": "Hello, World!",
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||
}
|
||||
|
||||
|
@ -3060,8 +3080,8 @@ func Benchmark_Ctx_RenderWithLocalsAndBinding(b *testing.B) {
|
|||
for n := 0; n < b.N; n++ {
|
||||
err = c.Render("template.tmpl", Map{})
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, nil, err)
|
||||
|
||||
utils.AssertEqual(b, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
|
||||
}
|
||||
|
||||
|
@ -3083,8 +3103,8 @@ func Benchmark_Ctx_RedirectToRoute(b *testing.B) {
|
|||
"name": "fiber",
|
||||
})
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, nil, err)
|
||||
|
||||
utils.AssertEqual(b, 302, c.Response().StatusCode())
|
||||
utils.AssertEqual(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation)))
|
||||
}
|
||||
|
@ -3108,8 +3128,8 @@ func Benchmark_Ctx_RedirectToRouteWithQueries(b *testing.B) {
|
|||
"queries": map[string]string{"a": "a", "b": "b"},
|
||||
})
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, nil, err)
|
||||
|
||||
utils.AssertEqual(b, 302, c.Response().StatusCode())
|
||||
// analysis of query parameters with url parsing, since a map pass is always randomly ordered
|
||||
location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation)))
|
||||
|
@ -3138,8 +3158,8 @@ func Benchmark_Ctx_RenderLocals(b *testing.B) {
|
|||
for n := 0; n < b.N; n++ {
|
||||
err = c.Render("index.tmpl", Map{})
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, nil, err)
|
||||
|
||||
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||
}
|
||||
|
||||
|
@ -3164,8 +3184,8 @@ func Benchmark_Ctx_RenderBind(b *testing.B) {
|
|||
for n := 0; n < b.N; n++ {
|
||||
err = c.Render("index.tmpl", Map{})
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, nil, err)
|
||||
|
||||
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||
}
|
||||
|
||||
|
@ -3191,8 +3211,7 @@ func Test_Ctx_RestartRouting(t *testing.T) {
|
|||
func Test_Ctx_RestartRoutingWithChangedPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := New()
|
||||
executedOldHandler := false
|
||||
executedNewHandler := false
|
||||
var executedOldHandler, executedNewHandler bool
|
||||
|
||||
app.Get("/old", func(c *Ctx) error {
|
||||
c.Path("/new")
|
||||
|
@ -3242,10 +3261,18 @@ type testTemplateEngine struct {
|
|||
|
||||
func (t *testTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
|
||||
if len(layout) == 0 {
|
||||
return t.templates.ExecuteTemplate(w, name, bind)
|
||||
if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
|
||||
return fmt.Errorf("failed to execute template without layout: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
_ = t.templates.ExecuteTemplate(w, name, bind)
|
||||
return t.templates.ExecuteTemplate(w, layout[0], bind)
|
||||
if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
|
||||
return fmt.Errorf("failed to execute template: %w", err)
|
||||
}
|
||||
if err := t.templates.ExecuteTemplate(w, layout[0], bind); err != nil {
|
||||
return fmt.Errorf("failed to execute template with layout: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testTemplateEngine) Load() error {
|
||||
|
@ -3324,7 +3351,6 @@ func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) {
|
|||
}
|
||||
utils.AssertEqual(b, "/user/fiber", location)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_Get_Location_From_Route_name
|
||||
|
@ -3407,11 +3433,11 @@ func Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param(t *testing.
|
|||
|
||||
type errorTemplateEngine struct{}
|
||||
|
||||
func (t errorTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
|
||||
func (errorTemplateEngine) Render(_ io.Writer, _ string, _ interface{}, _ ...string) error {
|
||||
return errors.New("errorTemplateEngine")
|
||||
}
|
||||
|
||||
func (t errorTemplateEngine) Load() error { return nil }
|
||||
func (errorTemplateEngine) Load() error { return nil }
|
||||
|
||||
// go test -run Test_Ctx_Render_Engine_Error
|
||||
func Test_Ctx_Render_Engine_Error(t *testing.T) {
|
||||
|
@ -3429,7 +3455,10 @@ func Test_Ctx_Render_Go_Template(t *testing.T) {
|
|||
t.Parallel()
|
||||
file, err := os.CreateTemp(os.TempDir(), "fiber")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer os.Remove(file.Name())
|
||||
defer func() {
|
||||
err := os.Remove(file.Name())
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
_, err = file.Write([]byte("template"))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -3821,7 +3850,7 @@ func Test_Ctx_QueryParser(t *testing.T) {
|
|||
}
|
||||
rq := new(RequiredQuery)
|
||||
c.Request().URI().SetQueryString("")
|
||||
utils.AssertEqual(t, "name is empty", c.QueryParser(rq).Error())
|
||||
utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(rq).Error())
|
||||
|
||||
type ArrayQuery struct {
|
||||
Data []string
|
||||
|
@ -3837,7 +3866,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
|
|||
t.Parallel()
|
||||
type NonRFCTime time.Time
|
||||
|
||||
NonRFCConverter := func(value string) reflect.Value {
|
||||
nonRFCConverter := func(value string) reflect.Value {
|
||||
if v, err := time.Parse("2006-01-02", value); err == nil {
|
||||
return reflect.ValueOf(v)
|
||||
}
|
||||
|
@ -3846,7 +3875,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
|
|||
|
||||
nonRFCTime := ParserType{
|
||||
Customtype: NonRFCTime{},
|
||||
Converter: NonRFCConverter,
|
||||
Converter: nonRFCConverter,
|
||||
}
|
||||
|
||||
SetParserDecoder(ParserConfig{
|
||||
|
@ -3872,7 +3901,6 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
|
|||
|
||||
c.Request().URI().SetQueryString("date=2021-04-10&title=CustomDateTest&Body=October")
|
||||
utils.AssertEqual(t, nil, c.QueryParser(q))
|
||||
fmt.Println(q.Date, "q.Date")
|
||||
utils.AssertEqual(t, "CustomDateTest", q.Title)
|
||||
date := fmt.Sprintf("%v", q.Date)
|
||||
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
|
||||
|
@ -3907,7 +3935,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
|||
|
||||
c.Request().URI().SetQueryString("namex=tom&nested.age=10")
|
||||
q = new(Query1)
|
||||
utils.AssertEqual(t, "name is empty", c.QueryParser(q).Error())
|
||||
utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(q).Error())
|
||||
|
||||
c.Request().URI().SetQueryString("name=tom&nested.agex=10")
|
||||
q = new(Query1)
|
||||
|
@ -3915,7 +3943,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
|||
|
||||
c.Request().URI().SetQueryString("name=tom&test.age=10")
|
||||
q = new(Query1)
|
||||
utils.AssertEqual(t, "nested is empty", c.QueryParser(q).Error())
|
||||
utils.AssertEqual(t, "failed to decode: nested is empty", c.QueryParser(q).Error())
|
||||
|
||||
type Query2 struct {
|
||||
Name string `query:"name"`
|
||||
|
@ -3933,11 +3961,11 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
|||
|
||||
c.Request().URI().SetQueryString("nested.agex=10")
|
||||
q2 = new(Query2)
|
||||
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error())
|
||||
utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
|
||||
|
||||
c.Request().URI().SetQueryString("nested.agex=10")
|
||||
q2 = new(Query2)
|
||||
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error())
|
||||
utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
|
||||
|
||||
type Node struct {
|
||||
Value int `query:"val,required"`
|
||||
|
@ -3951,7 +3979,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
|||
|
||||
c.Request().URI().SetQueryString("next.val=2")
|
||||
n = new(Node)
|
||||
utils.AssertEqual(t, "val is empty", c.QueryParser(n).Error())
|
||||
utils.AssertEqual(t, "failed to decode: val is empty", c.QueryParser(n).Error())
|
||||
|
||||
c.Request().URI().SetQueryString("val=3&next.value=2")
|
||||
n = new(Node)
|
||||
|
@ -4057,7 +4085,7 @@ func Test_Ctx_ReqHeaderParser(t *testing.T) {
|
|||
}
|
||||
rh := new(RequiredHeader)
|
||||
c.Request().Header.Del("name")
|
||||
utils.AssertEqual(t, "name is empty", c.ReqHeaderParser(rh).Error())
|
||||
utils.AssertEqual(t, "failed to decode: name is empty", c.ReqHeaderParser(rh).Error())
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_ReqHeaderParser_WithSetParserDecoder -v
|
||||
|
@ -4065,7 +4093,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
|
|||
t.Parallel()
|
||||
type NonRFCTime time.Time
|
||||
|
||||
NonRFCConverter := func(value string) reflect.Value {
|
||||
nonRFCConverter := func(value string) reflect.Value {
|
||||
if v, err := time.Parse("2006-01-02", value); err == nil {
|
||||
return reflect.ValueOf(v)
|
||||
}
|
||||
|
@ -4074,7 +4102,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
|
|||
|
||||
nonRFCTime := ParserType{
|
||||
Customtype: NonRFCTime{},
|
||||
Converter: NonRFCConverter,
|
||||
Converter: nonRFCConverter,
|
||||
}
|
||||
|
||||
SetParserDecoder(ParserConfig{
|
||||
|
@ -4103,7 +4131,6 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
|
|||
c.Request().Header.Add("Body", "October")
|
||||
|
||||
utils.AssertEqual(t, nil, c.ReqHeaderParser(r))
|
||||
fmt.Println(r.Date, "q.Date")
|
||||
utils.AssertEqual(t, "CustomDateTest", r.Title)
|
||||
date := fmt.Sprintf("%v", r.Date)
|
||||
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
|
||||
|
@ -4140,7 +4167,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
|||
|
||||
c.Request().Header.Del("Name")
|
||||
q = new(Header1)
|
||||
utils.AssertEqual(t, "Name is empty", c.ReqHeaderParser(q).Error())
|
||||
utils.AssertEqual(t, "failed to decode: Name is empty", c.ReqHeaderParser(q).Error())
|
||||
|
||||
c.Request().Header.Add("Name", "tom")
|
||||
c.Request().Header.Del("Nested.Age")
|
||||
|
@ -4150,7 +4177,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
|||
|
||||
c.Request().Header.Del("Nested.Agex")
|
||||
q = new(Header1)
|
||||
utils.AssertEqual(t, "Nested is empty", c.ReqHeaderParser(q).Error())
|
||||
utils.AssertEqual(t, "failed to decode: Nested is empty", c.ReqHeaderParser(q).Error())
|
||||
|
||||
c.Request().Header.Del("Nested.Agex")
|
||||
c.Request().Header.Del("Name")
|
||||
|
@ -4176,7 +4203,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
|||
c.Request().Header.Del("Nested.Age")
|
||||
c.Request().Header.Add("Nested.Agex", "10")
|
||||
h2 = new(Header2)
|
||||
utils.AssertEqual(t, "Nested.age is empty", c.ReqHeaderParser(h2).Error())
|
||||
utils.AssertEqual(t, "failed to decode: Nested.age is empty", c.ReqHeaderParser(h2).Error())
|
||||
|
||||
type Node struct {
|
||||
Value int `reqHeader:"Val,required"`
|
||||
|
@ -4191,7 +4218,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
|||
|
||||
c.Request().Header.Del("Val")
|
||||
n = new(Node)
|
||||
utils.AssertEqual(t, "Val is empty", c.ReqHeaderParser(n).Error())
|
||||
utils.AssertEqual(t, "failed to decode: Val is empty", c.ReqHeaderParser(n).Error())
|
||||
|
||||
c.Request().Header.Add("Val", "3")
|
||||
c.Request().Header.Del("Next.Val")
|
||||
|
@ -4628,8 +4655,9 @@ func Test_Ctx_RepeatParserWithSameStruct(t *testing.T) {
|
|||
|
||||
var gzipJSON bytes.Buffer
|
||||
w := gzip.NewWriter(&gzipJSON)
|
||||
_, _ = w.Write([]byte(`{"body_param":"body_param"}`))
|
||||
_ = w.Close()
|
||||
_, _ = w.Write([]byte(`{"body_param":"body_param"}`)) //nolint:errcheck // This will never fail
|
||||
err := w.Close()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
c.Request().Header.SetContentType(MIMEApplicationJSON)
|
||||
c.Request().Header.Set(HeaderContentEncoding, "gzip")
|
||||
c.Request().SetBody(gzipJSON.Bytes())
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
jerrors "encoding/json"
|
||||
|
||||
"github.com/gofiber/fiber/v2/internal/schema"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
@ -36,42 +35,42 @@ func TestMultiError(t *testing.T) {
|
|||
|
||||
func TestInvalidUnmarshalError(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *jerrors.InvalidUnmarshalError
|
||||
var e *json.InvalidUnmarshalError
|
||||
ok := errors.As(&InvalidUnmarshalError{}, &e)
|
||||
utils.AssertEqual(t, true, ok)
|
||||
}
|
||||
|
||||
func TestMarshalerError(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *jerrors.MarshalerError
|
||||
var e *json.MarshalerError
|
||||
ok := errors.As(&MarshalerError{}, &e)
|
||||
utils.AssertEqual(t, true, ok)
|
||||
}
|
||||
|
||||
func TestSyntaxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *jerrors.SyntaxError
|
||||
var e *json.SyntaxError
|
||||
ok := errors.As(&SyntaxError{}, &e)
|
||||
utils.AssertEqual(t, true, ok)
|
||||
}
|
||||
|
||||
func TestUnmarshalTypeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *jerrors.UnmarshalTypeError
|
||||
var e *json.UnmarshalTypeError
|
||||
ok := errors.As(&UnmarshalTypeError{}, &e)
|
||||
utils.AssertEqual(t, true, ok)
|
||||
}
|
||||
|
||||
func TestUnsupportedTypeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *jerrors.UnsupportedTypeError
|
||||
var e *json.UnsupportedTypeError
|
||||
ok := errors.As(&UnsupportedTypeError{}, &e)
|
||||
utils.AssertEqual(t, true, ok)
|
||||
}
|
||||
|
||||
func TestUnsupportedValeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *jerrors.UnsupportedValueError
|
||||
var e *json.UnsupportedValueError
|
||||
ok := errors.As(&UnsupportedValueError{}, &e)
|
||||
utils.AssertEqual(t, true, ok)
|
||||
}
|
||||
|
|
1
group.go
1
group.go
|
@ -168,7 +168,6 @@ func (grp *Group) Group(prefix string, handlers ...Handler) Router {
|
|||
}
|
||||
|
||||
return newGrp
|
||||
|
||||
}
|
||||
|
||||
// Route is used to define routes with a common prefix inside the common function.
|
||||
|
|
78
helpers.go
78
helpers.go
|
@ -20,13 +20,13 @@ import (
|
|||
"unsafe"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
/* #nosec */
|
||||
// getTlsConfig returns a net listener's tls config
|
||||
func getTlsConfig(ln net.Listener) *tls.Config {
|
||||
// getTLSConfig returns a net listener's tls config
|
||||
func getTLSConfig(ln net.Listener) *tls.Config {
|
||||
// Get listener type
|
||||
pointer := reflect.ValueOf(ln)
|
||||
|
||||
|
@ -37,12 +37,16 @@ func getTlsConfig(ln net.Listener) *tls.Config {
|
|||
// Get private field from value
|
||||
if field := val.FieldByName("config"); field.Type() != nil {
|
||||
// Copy value from pointer field (unsafe)
|
||||
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) // #nosec G103
|
||||
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) //nolint:gosec // Probably the only way to extract the *tls.Config from a net.Listener. TODO: Verify there really is no easier way without using unsafe.
|
||||
if newval.Type() != nil {
|
||||
// Get element from pointer
|
||||
if elem := newval.Elem(); elem.Type() != nil {
|
||||
// Cast value to *tls.Config
|
||||
return elem.Interface().(*tls.Config)
|
||||
c, ok := elem.Interface().(*tls.Config)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("failed to type-assert to *tls.Config"))
|
||||
}
|
||||
return c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -53,19 +57,21 @@ func getTlsConfig(ln net.Listener) *tls.Config {
|
|||
}
|
||||
|
||||
// readContent opens a named file and read content from it
|
||||
func readContent(rf io.ReaderFrom, name string) (n int64, err error) {
|
||||
func readContent(rf io.ReaderFrom, name string) (int64, error) {
|
||||
// Read file
|
||||
f, err := os.Open(filepath.Clean(name))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
// #nosec G307
|
||||
defer func() {
|
||||
if err = f.Close(); err != nil {
|
||||
log.Printf("Error closing file: %s\n", err)
|
||||
}
|
||||
}()
|
||||
return rf.ReadFrom(f)
|
||||
if n, err := rf.ReadFrom(f); err != nil {
|
||||
return n, fmt.Errorf("failed to read: %w", err)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// quoteString escape special characters in a given string
|
||||
|
@ -78,7 +84,8 @@ func (app *App) quoteString(raw string) string {
|
|||
}
|
||||
|
||||
// Scan stack if other methods match the request
|
||||
func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
||||
func (app *App) methodExist(ctx *Ctx) bool {
|
||||
var exists bool
|
||||
methods := app.config.RequestMethods
|
||||
for i := 0; i < len(methods); i++ {
|
||||
// Skip original method
|
||||
|
@ -108,7 +115,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
|||
// No match, next route
|
||||
if match {
|
||||
// We matched
|
||||
exist = true
|
||||
exists = true
|
||||
// Add method to Allow header
|
||||
ctx.Append(HeaderAllow, methods[i])
|
||||
// Break stack loop
|
||||
|
@ -116,7 +123,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
|||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
return exists
|
||||
}
|
||||
|
||||
// uniqueRouteStack drop all not unique routes from the slice
|
||||
|
@ -146,7 +153,7 @@ func defaultString(value string, defaultValue []string) string {
|
|||
const normalizedHeaderETag = "Etag"
|
||||
|
||||
// Generate and set ETag header to response
|
||||
func setETag(c *Ctx, weak bool) {
|
||||
func setETag(c *Ctx, weak bool) { //nolint: revive // Accepting a bool param is fine here
|
||||
// Don't generate ETags for invalid responses
|
||||
if c.fasthttp.Response.StatusCode() != StatusOK {
|
||||
return
|
||||
|
@ -160,7 +167,8 @@ func setETag(c *Ctx, weak bool) {
|
|||
clientEtag := c.Get(HeaderIfNoneMatch)
|
||||
|
||||
// Generate ETag for response
|
||||
crc32q := crc32.MakeTable(0xD5828281)
|
||||
const pol = 0xD5828281
|
||||
crc32q := crc32.MakeTable(pol)
|
||||
etag := fmt.Sprintf("\"%d-%v\"", len(body), crc32.Checksum(body, crc32q))
|
||||
|
||||
// Enable weak tag
|
||||
|
@ -173,7 +181,9 @@ func setETag(c *Ctx, weak bool) {
|
|||
// Check if server's ETag is weak
|
||||
if clientEtag[2:] == etag || clientEtag[2:] == etag[2:] {
|
||||
// W/1 == 1 || W/1 == W/1
|
||||
_ = c.SendStatus(StatusNotModified)
|
||||
if err := c.SendStatus(StatusNotModified); err != nil {
|
||||
log.Printf("setETag: failed to SendStatus: %v\n", err)
|
||||
}
|
||||
c.fasthttp.ResetBody()
|
||||
return
|
||||
}
|
||||
|
@ -183,7 +193,9 @@ func setETag(c *Ctx, weak bool) {
|
|||
}
|
||||
if strings.Contains(clientEtag, etag) {
|
||||
// 1 == 1
|
||||
_ = c.SendStatus(StatusNotModified)
|
||||
if err := c.SendStatus(StatusNotModified); err != nil {
|
||||
log.Printf("setETag: failed to SendStatus: %v\n", err)
|
||||
}
|
||||
c.fasthttp.ResetBody()
|
||||
return
|
||||
}
|
||||
|
@ -239,7 +251,7 @@ func getOffer(header string, offers ...string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func matchEtag(s string, etag string) bool {
|
||||
func matchEtag(s, etag string) bool {
|
||||
if s == etag || s == "W/"+etag || "W/"+s == etag {
|
||||
return true
|
||||
}
|
||||
|
@ -254,12 +266,12 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
|
|||
// https://github.com/jshttp/fresh/blob/10e0471669dbbfbfd8de65bc6efac2ddd0bfa057/index.js#L110
|
||||
for i := range noneMatchBytes {
|
||||
switch noneMatchBytes[i] {
|
||||
case 0x20:
|
||||
case 0x20: //nolint:gomnd // This is a space (" ")
|
||||
if start == end {
|
||||
start = i + 1
|
||||
end = i + 1
|
||||
}
|
||||
case 0x2c:
|
||||
case 0x2c: //nolint:gomnd // This is a comma (",")
|
||||
if matchEtag(app.getString(noneMatchBytes[start:end]), etag) {
|
||||
return false
|
||||
}
|
||||
|
@ -273,7 +285,7 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
|
|||
return !matchEtag(app.getString(noneMatchBytes[start:end]), etag)
|
||||
}
|
||||
|
||||
func parseAddr(raw string) (host, port string) {
|
||||
func parseAddr(raw string) (string, string) { //nolint:revive // Returns (host, port)
|
||||
if i := strings.LastIndex(raw, ":"); i != -1 {
|
||||
return raw[:i], raw[i+1:]
|
||||
}
|
||||
|
@ -313,21 +325,21 @@ type testConn struct {
|
|||
w bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) }
|
||||
func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) }
|
||||
func (c *testConn) Close() error { return nil }
|
||||
func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) } //nolint:wrapcheck // This must not be wrapped
|
||||
func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) } //nolint:wrapcheck // This must not be wrapped
|
||||
func (*testConn) Close() error { return nil }
|
||||
|
||||
func (c *testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
||||
func (c *testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
||||
func (c *testConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (c *testConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c *testConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
func (*testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
||||
func (*testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
||||
func (*testConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (*testConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (*testConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
var getStringImmutable = func(b []byte) string {
|
||||
func getStringImmutable(b []byte) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
var getBytesImmutable = func(s string) (b []byte) {
|
||||
func getBytesImmutable(s string) []byte {
|
||||
return []byte(s)
|
||||
}
|
||||
|
||||
|
@ -335,6 +347,7 @@ var getBytesImmutable = func(s string) (b []byte) {
|
|||
func (app *App) methodInt(s string) int {
|
||||
// For better performance
|
||||
if len(app.configured.RequestMethods) == 0 {
|
||||
//nolint:gomnd // TODO: Use iota instead
|
||||
switch s {
|
||||
case MethodGet:
|
||||
return 0
|
||||
|
@ -391,8 +404,7 @@ func IsMethodIdempotent(m string) bool {
|
|||
}
|
||||
|
||||
switch m {
|
||||
case MethodPut,
|
||||
MethodDelete:
|
||||
case MethodPut, MethodDelete:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
@ -714,7 +726,7 @@ const (
|
|||
ConstraintBool = "bool"
|
||||
ConstraintFloat = "float"
|
||||
ConstraintAlpha = "alpha"
|
||||
ConstraintGuid = "guid"
|
||||
ConstraintGuid = "guid" //nolint:revive,stylecheck // TODO: Rename to "ConstraintGUID" in v3
|
||||
ConstraintMinLen = "minLen"
|
||||
ConstraintMaxLen = "maxLen"
|
||||
ConstraintLen = "len"
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
|
30
hooks.go
30
hooks.go
|
@ -1,14 +1,20 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"log"
|
||||
)
|
||||
|
||||
// OnRouteHandler Handlers define a function to create hooks for Fiber.
|
||||
type OnRouteHandler = func(Route) error
|
||||
type OnNameHandler = OnRouteHandler
|
||||
type OnGroupHandler = func(Group) error
|
||||
type OnGroupNameHandler = OnGroupHandler
|
||||
type OnListenHandler = func() error
|
||||
type OnShutdownHandler = OnListenHandler
|
||||
type OnForkHandler = func(int) error
|
||||
type OnMountHandler = func(*App) error
|
||||
type (
|
||||
OnRouteHandler = func(Route) error
|
||||
OnNameHandler = OnRouteHandler
|
||||
OnGroupHandler = func(Group) error
|
||||
OnGroupNameHandler = OnGroupHandler
|
||||
OnListenHandler = func() error
|
||||
OnShutdownHandler = OnListenHandler
|
||||
OnForkHandler = func(int) error
|
||||
OnMountHandler = func(*App) error
|
||||
)
|
||||
|
||||
// Hooks is a struct to use it with App.
|
||||
type Hooks struct {
|
||||
|
@ -180,13 +186,17 @@ func (h *Hooks) executeOnListenHooks() error {
|
|||
|
||||
func (h *Hooks) executeOnShutdownHooks() {
|
||||
for _, v := range h.onShutdown {
|
||||
_ = v()
|
||||
if err := v(); err != nil {
|
||||
log.Printf("failed to call shutdown hook: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hooks) executeOnForkHooks(pid int) {
|
||||
for _, v := range h.onFork {
|
||||
_ = v(pid)
|
||||
if err := v(pid); err != nil {
|
||||
log.Printf("failed to call fork hook: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,10 +7,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
)
|
||||
|
||||
var testSimpleHandler = func(c *Ctx) error {
|
||||
func testSimpleHandler(c *Ctx) error {
|
||||
return c.SendString("simple")
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ var defaultPool = sync.Pool{
|
|||
|
||||
// AcquireDict acquire new dict.
|
||||
func AcquireDict() *Dict {
|
||||
return defaultPool.Get().(*Dict) // nolint:forcetypeassert
|
||||
return defaultPool.Get().(*Dict)
|
||||
}
|
||||
|
||||
// ReleaseDict release dict.
|
||||
|
|
|
@ -366,7 +366,7 @@ func HostDev(combineWith ...string) string {
|
|||
// getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running
|
||||
// sysctl commands (see DoSysctrl).
|
||||
func getSysctrlEnv(env []string) []string {
|
||||
foundLC := false
|
||||
var foundLC bool
|
||||
for i, line := range env {
|
||||
if strings.HasPrefix(line, "LC_ALL") {
|
||||
env[i] = "LC_ALL=C"
|
||||
|
|
|
@ -6,8 +6,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2/internal/gopsutil/common"
|
||||
)
|
||||
|
||||
//lint:ignore U1000 we need this elsewhere
|
||||
var invoke common.Invoker = common.Invoke{} //nolint:all
|
||||
var invoke common.Invoker = common.Invoke{} //nolint:unused // We use this only for some OS'es
|
||||
|
||||
// Memory usage statistics. Total, Available and Used contain numbers of bytes
|
||||
// for human consumption.
|
||||
|
|
|
@ -86,7 +86,6 @@ func SwapMemory() (*SwapMemoryStat, error) {
|
|||
}
|
||||
|
||||
// Constants from vm/vm_param.h
|
||||
// nolint: golint
|
||||
const (
|
||||
XSWDEV_VERSION11 = 1
|
||||
XSWDEV_VERSION = 2
|
||||
|
|
|
@ -57,10 +57,12 @@ func fillFromMeminfoWithContext(ctx context.Context) (*VirtualMemoryStat, *Virtu
|
|||
lines, _ := common.ReadLines(filename)
|
||||
|
||||
// flag if MemAvailable is in /proc/meminfo (kernel 3.14+)
|
||||
memavail := false
|
||||
activeFile := false // "Active(file)" not available: 2.6.28 / Dec 2008
|
||||
inactiveFile := false // "Inactive(file)" not available: 2.6.28 / Dec 2008
|
||||
sReclaimable := false // "SReclaimable:" not available: 2.6.19 / Nov 2006
|
||||
var (
|
||||
memavail bool
|
||||
activeFile bool // "Active(file)" not available: 2.6.28 / Dec 2008
|
||||
inactiveFile bool // "Inactive(file)" not available: 2.6.28 / Dec 2008
|
||||
sReclaimable bool // "SReclaimable:" not available: 2.6.19 / Nov 2006
|
||||
)
|
||||
|
||||
ret := &VirtualMemoryStat{}
|
||||
retEx := &VirtualMemoryExStat{}
|
||||
|
|
|
@ -168,7 +168,7 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
comma := false
|
||||
var comma bool
|
||||
for i := uint32(0); i < sz; i++ {
|
||||
if comma {
|
||||
err = dst.WriteByte(',')
|
||||
|
|
|
@ -163,7 +163,7 @@ func (c *cache) createField(field reflect.StructField, parentAlias string) *fiel
|
|||
}
|
||||
// Check if the type is supported and don't cache it if not.
|
||||
// First let's get the basic type.
|
||||
isSlice, isStruct := false, false
|
||||
var isSlice, isStruct bool
|
||||
ft := field.Type
|
||||
m := isTextUnmarshaler(reflect.Zero(ft))
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
|
|
|
@ -149,13 +149,13 @@ func Benchmark_Storage_Memory(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
for _, key := range keys {
|
||||
d.Set(key, value, ttl)
|
||||
_ = d.Set(key, value, ttl)
|
||||
}
|
||||
for _, key := range keys {
|
||||
_, _ = d.Get(key)
|
||||
}
|
||||
for _, key := range keys {
|
||||
d.Delete(key)
|
||||
_ = d.Delete(key)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -156,7 +156,6 @@ func (e *Engine) Load() error {
|
|||
name = strings.TrimSuffix(name, e.extension)
|
||||
// name = strings.Replace(name, e.extension, "", -1)
|
||||
// Read the file
|
||||
// #gosec G304
|
||||
buf, err := utils.ReadFile(path, e.fileSystem)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -21,7 +21,6 @@ func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
|
|||
return walk(fs, root, info, walkFn)
|
||||
}
|
||||
|
||||
// #nosec G304
|
||||
// ReadFile returns the raw content of a file
|
||||
func ReadFile(path string, fs http.FileSystem) ([]byte, error) {
|
||||
if fs != nil {
|
||||
|
|
62
listen.go
62
listen.go
|
@ -9,6 +9,7 @@ import (
|
|||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -31,7 +32,7 @@ func (app *App) Listener(ln net.Listener) error {
|
|||
|
||||
// Print startup message
|
||||
if !app.config.DisableStartupMessage {
|
||||
app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "")
|
||||
app.startupMessage(ln.Addr().String(), getTLSConfig(ln) != nil, "")
|
||||
}
|
||||
|
||||
// Print routes
|
||||
|
@ -41,7 +42,7 @@ func (app *App) Listener(ln net.Listener) error {
|
|||
|
||||
// Prefork is not supported for custom listeners
|
||||
if app.config.Prefork {
|
||||
fmt.Println("[Warning] Prefork isn't supported for custom listeners.")
|
||||
log.Printf("[Warning] Prefork isn't supported for custom listeners.\n")
|
||||
}
|
||||
|
||||
// Start listening
|
||||
|
@ -61,7 +62,7 @@ func (app *App) Listen(addr string) error {
|
|||
// Setup listener
|
||||
ln, err := net.Listen(app.config.Network, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
// prepare the server for the start
|
||||
|
@ -94,7 +95,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
|
|||
// Set TLS config with handler
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
|
||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
|
||||
}
|
||||
|
||||
tlsHandler := &TLSHandler{}
|
||||
|
@ -115,7 +116,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
|
|||
ln, err := net.Listen(app.config.Network, addr)
|
||||
ln = tls.NewListener(ln, config)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
// prepare the server for the start
|
||||
|
@ -150,12 +151,12 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
|
|||
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
|
||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
|
||||
}
|
||||
|
||||
clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
clientCertPool := x509.NewCertPool()
|
||||
clientCertPool.AppendCertsFromPEM(clientCACert)
|
||||
|
@ -179,7 +180,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
|
|||
// Setup listener
|
||||
ln, err := tls.Listen(app.config.Network, addr, config)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
// prepare the server for the start
|
||||
|
@ -203,7 +204,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
|
|||
}
|
||||
|
||||
// startupMessage prepares the startup message with the handler number, port, address and other information
|
||||
func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||
func (app *App) startupMessage(addr string, tls bool, pids string) { //nolint: revive // Accepting a bool param is fine here
|
||||
// ignore child processes
|
||||
if IsChild() {
|
||||
return
|
||||
|
@ -227,7 +228,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
}
|
||||
|
||||
center := func(s string, width int) string {
|
||||
pad := strconv.Itoa((width - len(s)) / 2)
|
||||
const padDiv = 2
|
||||
pad := strconv.Itoa((width - len(s)) / padDiv)
|
||||
str := fmt.Sprintf("%"+pad+"s", " ")
|
||||
str += s
|
||||
str += fmt.Sprintf("%"+pad+"s", " ")
|
||||
|
@ -238,7 +240,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
}
|
||||
|
||||
centerValue := func(s string, width int) string {
|
||||
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / 2)
|
||||
const padDiv = 2
|
||||
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / padDiv)
|
||||
str := fmt.Sprintf("%"+pad+"s", " ")
|
||||
str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black)
|
||||
str += fmt.Sprintf("%"+pad+"s", " ")
|
||||
|
@ -249,13 +252,13 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
return str
|
||||
}
|
||||
|
||||
pad := func(s string, width int) (str string) {
|
||||
pad := func(s string, width int) string {
|
||||
toAdd := width - len(s)
|
||||
str += s
|
||||
str := s
|
||||
for i := 0; i < toAdd; i++ {
|
||||
str += " "
|
||||
}
|
||||
return
|
||||
return str
|
||||
}
|
||||
|
||||
host, port := parseAddr(addr)
|
||||
|
@ -267,9 +270,9 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
}
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
scheme := schemeHTTP
|
||||
if tls {
|
||||
scheme = "https"
|
||||
scheme = schemeHTTPS
|
||||
}
|
||||
|
||||
isPrefork := "Disabled"
|
||||
|
@ -282,19 +285,18 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
procs = "1"
|
||||
}
|
||||
|
||||
const lineLen = 49
|
||||
mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n"
|
||||
if app.config.AppName != "" {
|
||||
mainLogo += " │ " + centerValue(app.config.AppName, 49) + " │\n"
|
||||
mainLogo += " │ " + centerValue(app.config.AppName, lineLen) + " │\n"
|
||||
}
|
||||
mainLogo += " │ " + centerValue("Fiber v"+Version, 49) + " │\n"
|
||||
mainLogo += " │ " + centerValue("Fiber v"+Version, lineLen) + " │\n"
|
||||
|
||||
if host == "0.0.0.0" {
|
||||
mainLogo +=
|
||||
" │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), 49) + " │\n" +
|
||||
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), 49) + " │\n"
|
||||
mainLogo += " │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), lineLen) + " │\n" +
|
||||
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), lineLen) + " │\n"
|
||||
} else {
|
||||
mainLogo +=
|
||||
" │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), 49) + " │\n"
|
||||
mainLogo += " │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), lineLen) + " │\n"
|
||||
}
|
||||
|
||||
mainLogo += fmt.Sprintf(
|
||||
|
@ -303,8 +305,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
" │ Prefork .%s PID ....%s │\n"+
|
||||
" └───────────────────────────────────────────────────┘"+
|
||||
colors.Reset,
|
||||
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12),
|
||||
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
|
||||
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12), //nolint:gomnd // Using random padding lengths is fine here
|
||||
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14), //nolint:gomnd // Using random padding lengths is fine here
|
||||
)
|
||||
|
||||
var childPidsLogo string
|
||||
|
@ -329,19 +331,21 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
thisLine := "Child PIDs ... "
|
||||
var itemsOnThisLine []string
|
||||
|
||||
const maxLineLen = 49
|
||||
|
||||
addLine := func() {
|
||||
lines = append(lines,
|
||||
fmt.Sprintf(
|
||||
newLine,
|
||||
colors.Black,
|
||||
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), 49-len(thisLine)),
|
||||
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), maxLineLen-len(thisLine)),
|
||||
colors.Black,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
for _, pid := range pidSlice {
|
||||
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > 49 {
|
||||
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > maxLineLen {
|
||||
addLine()
|
||||
thisLine = ""
|
||||
itemsOnThisLine = []string{pid}
|
||||
|
@ -415,7 +419,7 @@ func (app *App) printRoutesMessage() {
|
|||
var routes []RouteMessage
|
||||
for _, routeStack := range app.stack {
|
||||
for _, route := range routeStack {
|
||||
var newRoute = RouteMessage{}
|
||||
var newRoute RouteMessage
|
||||
newRoute.name = route.Name
|
||||
newRoute.method = route.Method
|
||||
newRoute.path = route.Path
|
||||
|
@ -443,5 +447,5 @@ func (app *App) printRoutesMessage() {
|
|||
_, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers)
|
||||
}
|
||||
|
||||
_ = w.Flush()
|
||||
_ = w.Flush() //nolint:errcheck // It is fine to ignore the error here
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ package fiber
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
@ -17,6 +16,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
)
|
||||
|
||||
|
@ -125,8 +125,10 @@ func Test_App_Listener_TLS_Listener(t *testing.T) {
|
|||
if err != nil {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
//nolint:gosec // We're in a test so using old ciphers is fine
|
||||
config := &tls.Config{Certificates: []tls.Certificate{cer}}
|
||||
|
||||
//nolint:gosec // We're in a test so listening on all interfaces is fine
|
||||
ln, err := tls.Listen(NetworkTCP4, ":0", config)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -182,7 +184,6 @@ func Test_App_Master_Process_Show_Startup_Message(t *testing.T) {
|
|||
New(Config{Prefork: true}).
|
||||
startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
|
||||
})
|
||||
fmt.Println(startupMessage)
|
||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "https://127.0.0.1:3000"))
|
||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "(bound on host 0.0.0.0 and port 3000)"))
|
||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "Child PIDs"))
|
||||
|
@ -196,7 +197,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) {
|
|||
startupMessage := captureOutput(func() {
|
||||
app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
|
||||
})
|
||||
fmt.Println(startupMessage)
|
||||
utils.AssertEqual(t, "Test App v1.0.1", app.Config().AppName)
|
||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, app.Config().AppName))
|
||||
}
|
||||
|
@ -208,7 +208,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.
|
|||
startupMessage := captureOutput(func() {
|
||||
app.startupMessage(":3000", false, "")
|
||||
})
|
||||
fmt.Println(startupMessage)
|
||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "│ Serveur de vérification des données │"))
|
||||
}
|
||||
|
||||
|
@ -219,8 +218,7 @@ func Test_App_print_Route(t *testing.T) {
|
|||
printRoutesMessage := captureOutput(func() {
|
||||
app.printRoutesMessage()
|
||||
})
|
||||
fmt.Println(printRoutesMessage)
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "routeName"))
|
||||
|
@ -240,11 +238,11 @@ func Test_App_print_Route_with_group(t *testing.T) {
|
|||
app.printRoutesMessage()
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "POST"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodPost))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "PUT"))
|
||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber/*"))
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
b64 "encoding/base64"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,7 @@ func Test_BasicAuth_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ func Test_Middleware_BasicAuth(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
//nolint:forcetypeassert,errcheck // TODO: Do not force-type assert
|
||||
app.Get("/testauth", func(c *fiber.Ctx) error {
|
||||
username := c.Locals("username").(string)
|
||||
password := c.Locals("password").(string)
|
||||
|
@ -74,9 +75,9 @@ func Test_Middleware_BasicAuth(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
// Base64 encode credentials for http auth header
|
||||
creds := b64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
|
||||
creds := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
|
||||
|
||||
req := httptest.NewRequest("GET", "/testauth", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil)
|
||||
req.Header.Add("Authorization", "Basic "+creds)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -108,7 +109,7 @@ func Benchmark_Middleware_BasicAuth(b *testing.B) {
|
|||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe
|
||||
|
||||
|
|
|
@ -53,6 +53,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Users: map[string]string{},
|
||||
|
|
|
@ -34,6 +34,7 @@ const (
|
|||
noStore = "no-store"
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var ignoreHeaders = map[string]interface{}{
|
||||
"Connection": nil,
|
||||
"Keep-Alive": nil,
|
||||
|
@ -43,8 +44,8 @@ var ignoreHeaders = map[string]interface{}{
|
|||
"Trailers": nil,
|
||||
"Transfer-Encoding": nil,
|
||||
"Upgrade": nil,
|
||||
"Content-Type": nil, // already stored explicitely by the cache manager
|
||||
"Content-Encoding": nil, // already stored explicitely by the cache manager
|
||||
"Content-Type": nil, // already stored explicitly by the cache manager
|
||||
"Content-Encoding": nil, // already stored explicitly by the cache manager
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
|
@ -69,7 +70,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// Create indexed heap for tracking expirations ( see heap.go )
|
||||
heap := &indexedHeap{}
|
||||
// count stored bytes (sizes of response bodies)
|
||||
var storedBytes uint = 0
|
||||
var storedBytes uint
|
||||
|
||||
// Update timestamp in the configured interval
|
||||
go func() {
|
||||
|
@ -81,10 +82,10 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// Delete key from both manager and storage
|
||||
deleteKey := func(dkey string) {
|
||||
manager.delete(dkey)
|
||||
manager.del(dkey)
|
||||
// External storage saves body data with different key
|
||||
if cfg.Storage != nil {
|
||||
manager.delete(dkey + "_body")
|
||||
manager.del(dkey + "_body")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -205,7 +206,7 @@ func New(config ...Config) fiber.Handler {
|
|||
if cfg.StoreResponseHeaders {
|
||||
e.headers = make(map[string][]byte)
|
||||
c.Response().Header.VisitAll(
|
||||
func(key []byte, value []byte) {
|
||||
func(key, value []byte) {
|
||||
// create real copy
|
||||
keyS := string(key)
|
||||
if _, ok := ignoreHeaders[keyS]; !ok {
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
|
@ -18,6 +17,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/middleware/etag"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -35,10 +35,10 @@ func Test_Cache_CacheControl(t *testing.T) {
|
|||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ func Test_Cache_Expired(t *testing.T) {
|
|||
return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -61,7 +61,7 @@ func Test_Cache_Expired(t *testing.T) {
|
|||
// Sleep until the cache is expired
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCached, err := io.ReadAll(respCached.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -71,7 +71,7 @@ func Test_Cache_Expired(t *testing.T) {
|
|||
}
|
||||
|
||||
// Next response should be also cached
|
||||
respCachedNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -92,11 +92,11 @@ func Test_Cache(t *testing.T) {
|
|||
return c.SendString(now)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
cachedReq := httptest.NewRequest("GET", "/", nil)
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -120,31 +120,31 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
|
|||
})
|
||||
|
||||
// Request id = 1
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("1"), body)
|
||||
// Response cached, entry id = 1
|
||||
|
||||
// Request id = 2 without Cache-Control: no-cache
|
||||
cachedReq := httptest.NewRequest("GET", "/?id=2", nil)
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
defer cachedResp.Body.Close()
|
||||
cachedBody, _ := io.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := io.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("1"), cachedBody)
|
||||
// Response not cached, returns cached response, entry id = 1
|
||||
|
||||
// Request id = 2 with Cache-Control: no-cache
|
||||
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil)
|
||||
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheResp, err := app.Test(noCacheReq)
|
||||
defer noCacheResp.Body.Close()
|
||||
noCacheBody, _ := io.ReadAll(noCacheResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
noCacheBody, err := io.ReadAll(noCacheResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("2"), noCacheBody)
|
||||
|
@ -152,21 +152,21 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
|
|||
|
||||
/* Check Test_Cache_WithETagAndNoCacheRequestDirective */
|
||||
// Request id = 2 with Cache-Control: no-cache again
|
||||
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil)
|
||||
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheResp1, err := app.Test(noCacheReq1)
|
||||
defer noCacheResp1.Body.Close()
|
||||
noCacheBody1, _ := io.ReadAll(noCacheResp1.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
noCacheBody1, err := io.ReadAll(noCacheResp1.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("2"), noCacheBody1)
|
||||
// Response cached, returns updated response, entry = 2
|
||||
|
||||
// Request id = 1 without Cache-Control: no-cache
|
||||
cachedReq1 := httptest.NewRequest("GET", "/", nil)
|
||||
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp1, err := app.Test(cachedReq1)
|
||||
defer cachedResp1.Body.Close()
|
||||
cachedBody1, _ := io.ReadAll(cachedResp1.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody1, err := io.ReadAll(cachedResp1.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("2"), cachedBody1)
|
||||
|
@ -188,7 +188,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
|||
})
|
||||
|
||||
// Request id = 1
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
|
@ -199,7 +199,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
|||
etagToken := resp.Header.Get("Etag")
|
||||
|
||||
// Request id = 2 with ETag but without Cache-Control: no-cache
|
||||
cachedReq := httptest.NewRequest("GET", "/?id=2", nil)
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -208,7 +208,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
|||
// Response not cached, returns cached response, entry id = 1, status not modified
|
||||
|
||||
// Request id = 2 with ETag and Cache-Control: no-cache
|
||||
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil)
|
||||
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||
noCacheResp, err := app.Test(noCacheReq)
|
||||
|
@ -221,7 +221,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
|||
etagToken = noCacheResp.Header.Get("Etag")
|
||||
|
||||
// Request id = 2 with ETag and Cache-Control: no-cache again
|
||||
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil)
|
||||
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||
noCacheResp1, err := app.Test(noCacheReq1)
|
||||
|
@ -231,7 +231,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
|||
// Response cached, returns updated response, entry id = 2, status not modified
|
||||
|
||||
// Request id = 1 without ETag and Cache-Control: no-cache
|
||||
cachedReq1 := httptest.NewRequest("GET", "/", nil)
|
||||
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp1, err := app.Test(cachedReq1)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
||||
|
@ -251,11 +251,11 @@ func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
|
|||
})
|
||||
|
||||
// Request id = 2
|
||||
noStoreReq := httptest.NewRequest("GET", "/?id=2", nil)
|
||||
noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
|
||||
noStoreResp, err := app.Test(noStoreReq)
|
||||
defer noStoreResp.Body.Close()
|
||||
noStoreBody, _ := io.ReadAll(noStoreResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
noStoreBody, err := io.ReadAll(noStoreResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, []byte("2"), noStoreBody)
|
||||
// Response not cached, returns updated response
|
||||
|
@ -278,11 +278,11 @@ func Test_Cache_WithSeveralRequests(t *testing.T) {
|
|||
for runs := 0; runs < 10; runs++ {
|
||||
for i := 0; i < 10; i++ {
|
||||
func(id int) {
|
||||
rsp, err := app.Test(httptest.NewRequest(http.MethodGet, fmt.Sprintf("/%d", id), nil))
|
||||
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
defer func(body io.ReadCloser) {
|
||||
err := body.Close()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}(rsp.Body)
|
||||
|
||||
|
@ -311,11 +311,11 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
|
|||
return c.SendString(now)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
cachedReq := httptest.NewRequest("GET", "/", nil)
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -342,25 +342,25 @@ func Test_Cache_Get(t *testing.T) {
|
|||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "12345", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -384,25 +384,25 @@ func Test_Cache_Post(t *testing.T) {
|
|||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -420,14 +420,14 @@ func Test_Cache_NothingToCache(t *testing.T) {
|
|||
return c.SendString(time.Now().String())
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCached, err := io.ReadAll(respCached.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -457,22 +457,22 @@ func Test_Cache_CustomNext(t *testing.T) {
|
|||
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCached, err := io.ReadAll(respCached.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Equal(body, bodyCached))
|
||||
utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "")
|
||||
|
||||
_, err = app.Test(httptest.NewRequest("GET", "/error", nil))
|
||||
_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil))
|
||||
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "")
|
||||
}
|
||||
|
@ -491,7 +491,7 @@ func Test_CustomKey(t *testing.T) {
|
|||
return c.SendString("hi")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
_, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, called)
|
||||
|
@ -505,7 +505,9 @@ func Test_CustomExpiration(t *testing.T) {
|
|||
var newCacheTime int
|
||||
app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration {
|
||||
called = true
|
||||
newCacheTime, _ = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
|
||||
var err error
|
||||
newCacheTime, err = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
return time.Second * time.Duration(newCacheTime)
|
||||
}}))
|
||||
|
||||
|
@ -515,7 +517,7 @@ func Test_CustomExpiration(t *testing.T) {
|
|||
return c.SendString(now)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, called)
|
||||
utils.AssertEqual(t, 1, newCacheTime)
|
||||
|
@ -523,7 +525,7 @@ func Test_CustomExpiration(t *testing.T) {
|
|||
// Sleep until the cache is expired
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
cachedResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
@ -536,7 +538,7 @@ func Test_CustomExpiration(t *testing.T) {
|
|||
}
|
||||
|
||||
// Next response should be cached
|
||||
cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
cachedRespNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -559,12 +561,12 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) {
|
|||
return c.SendString("hi")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||
|
||||
req = httptest.NewRequest("GET", "/", nil)
|
||||
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||
|
@ -594,19 +596,19 @@ func Test_CacheHeader(t *testing.T) {
|
|||
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache"))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache"))
|
||||
|
||||
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil))
|
||||
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
|
||||
}
|
||||
|
@ -622,12 +624,12 @@ func Test_Cache_WithHead(t *testing.T) {
|
|||
return c.SendString(now)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("HEAD", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodHead, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
|
||||
cachedReq := httptest.NewRequest("HEAD", "/", nil)
|
||||
cachedReq := httptest.NewRequest(fiber.MethodHead, "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
||||
|
@ -649,28 +651,28 @@ func Test_Cache_WithHeadThenGet(t *testing.T) {
|
|||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
headResp, err := app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil))
|
||||
headResp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
headBody, err := io.ReadAll(headResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "", string(headBody))
|
||||
utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache"))
|
||||
|
||||
headResp, err = app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil))
|
||||
headResp, err = app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
headBody, err = io.ReadAll(headResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "", string(headBody))
|
||||
utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache"))
|
||||
|
||||
getResp, err := app.Test(httptest.NewRequest("GET", "/?cache=123", nil))
|
||||
getResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
getBody, err := io.ReadAll(getResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(getBody))
|
||||
utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache"))
|
||||
|
||||
getResp, err = app.Test(httptest.NewRequest("GET", "/?cache=123", nil))
|
||||
getResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
getBody, err = io.ReadAll(getResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -691,7 +693,7 @@ func Test_CustomCacheHeader(t *testing.T) {
|
|||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status"))
|
||||
}
|
||||
|
@ -702,7 +704,7 @@ func Test_CustomCacheHeader(t *testing.T) {
|
|||
func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
||||
i := 0
|
||||
return func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
||||
i += 1
|
||||
i++
|
||||
return time.Hour * time.Duration(i)
|
||||
}
|
||||
}
|
||||
|
@ -738,7 +740,7 @@ func Test_Cache_MaxBytesOrder(t *testing.T) {
|
|||
}
|
||||
|
||||
for idx, tcase := range cases {
|
||||
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil))
|
||||
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
||||
}
|
||||
|
@ -756,7 +758,8 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
|
|||
|
||||
app.Get("/*", func(c *fiber.Ctx) error {
|
||||
path := c.Context().URI().LastPathSegment()
|
||||
size, _ := strconv.Atoi(string(path))
|
||||
size, err := strconv.Atoi(string(path))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
return c.Send(make([]byte, size))
|
||||
})
|
||||
|
||||
|
@ -772,7 +775,7 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
|
|||
}
|
||||
|
||||
for idx, tcase := range cases {
|
||||
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil))
|
||||
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
||||
}
|
||||
|
@ -785,14 +788,14 @@ func Benchmark_Cache(b *testing.B) {
|
|||
app.Use(New())
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
data, _ := os.ReadFile("../../.github/README.md")
|
||||
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
|
||||
return c.Status(fiber.StatusTeapot).Send(data)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/demo")
|
||||
|
||||
b.ReportAllocs()
|
||||
|
@ -815,14 +818,14 @@ func Benchmark_Cache_Storage(b *testing.B) {
|
|||
}))
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
data, _ := os.ReadFile("../../.github/README.md")
|
||||
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
|
||||
return c.Status(fiber.StatusTeapot).Send(data)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/demo")
|
||||
|
||||
b.ReportAllocs()
|
||||
|
@ -850,7 +853,7 @@ func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
|
|||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/demo")
|
||||
|
||||
b.ReportAllocs()
|
||||
|
@ -882,7 +885,7 @@ func Benchmark_Cache_MaxSize(b *testing.B) {
|
|||
|
||||
h := app.Handler()
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
@ -49,10 +49,10 @@ type Config struct {
|
|||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// Deprecated, use Storage instead
|
||||
// Deprecated: Use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated, use KeyGenerator instead
|
||||
// Deprecated: Use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// allows you to store additional headers generated by next middlewares & handler
|
||||
|
@ -75,6 +75,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
|
@ -102,11 +104,11 @@ func configDefault(config ...Config) Config {
|
|||
|
||||
// Set default values
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("[CACHE] Store is deprecated, please use Storage")
|
||||
log.Printf("[CACHE] Store is deprecated, please use Storage\n")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
fmt.Println("[CACHE] Key is deprecated, please use KeyGenerator")
|
||||
log.Printf("[CACHE] Key is deprecated, please use KeyGenerator\n")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
|
|
|
@ -41,7 +41,7 @@ func (h indexedHeap) Swap(i, j int) {
|
|||
}
|
||||
|
||||
func (h *indexedHeap) Push(x interface{}) {
|
||||
h.pushInternal(x.(heapEntry))
|
||||
h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert // Forced type assertion required to implement the heap.Interface interface
|
||||
}
|
||||
|
||||
func (h *indexedHeap) Pop() interface{} {
|
||||
|
@ -65,7 +65,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
|
|||
idx = h.entries[:n+1][n].idx
|
||||
} else {
|
||||
idx = h.maxidx
|
||||
h.maxidx += 1
|
||||
h.maxidx++
|
||||
h.indices = append(h.indices, idx)
|
||||
}
|
||||
// Push manually to avoid allocation
|
||||
|
@ -77,7 +77,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
|
|||
}
|
||||
|
||||
func (h *indexedHeap) removeInternal(realIdx int) (string, uint) {
|
||||
x := heap.Remove(h, realIdx).(heapEntry)
|
||||
x := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
|
||||
return x.key, x.bytes
|
||||
}
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ func newManager(storage fiber.Storage) *manager {
|
|||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
|
@ -69,38 +69,47 @@ func (m *manager) release(e *item) {
|
|||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
func (m *manager) get(key string) *item {
|
||||
var it *item
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
raw, err := m.storage.Get(key)
|
||||
if err != nil {
|
||||
return it
|
||||
}
|
||||
if raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
return it
|
||||
}
|
||||
}
|
||||
return
|
||||
return it
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||
it = m.acquire()
|
||||
return it
|
||||
}
|
||||
return
|
||||
return it
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
func (m *manager) getRaw(key string) []byte {
|
||||
var raw []byte
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Handle error here
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
return
|
||||
return raw
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
|
@ -109,16 +118,16 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
|
|||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
func (m *manager) del(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Handle error here
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package compress
|
|||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
|
|
@ -12,8 +12,10 @@ import (
|
|||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var filedata []byte
|
||||
|
||||
//nolint:gochecknoinits // init() is used to populate a global var from a README file
|
||||
func init() {
|
||||
dat, err := os.ReadFile("../../.github/README.md")
|
||||
if err != nil {
|
||||
|
@ -34,7 +36,7 @@ func Test_Compress_Gzip(t *testing.T) {
|
|||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
|
@ -64,7 +66,7 @@ func Test_Compress_Different_Level(t *testing.T) {
|
|||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
|
@ -90,7 +92,7 @@ func Test_Compress_Deflate(t *testing.T) {
|
|||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
|
@ -114,7 +116,7 @@ func Test_Compress_Brotli(t *testing.T) {
|
|||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
resp, err := app.Test(req, 10000)
|
||||
|
@ -138,7 +140,7 @@ func Test_Compress_Disabled(t *testing.T) {
|
|||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
|
@ -162,7 +164,7 @@ func Test_Compress_Next_Error(t *testing.T) {
|
|||
return errors.New("next error")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
|
@ -185,7 +187,7 @@ func Test_Compress_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
|
|
@ -33,6 +33,8 @@ const (
|
|||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Level: LevelDefault,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -54,6 +53,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
AllowOrigins: "*",
|
||||
|
@ -128,7 +129,7 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Simple request
|
||||
if c.Method() != http.MethodOptions {
|
||||
if c.Method() != fiber.MethodOptions {
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -237,7 +238,7 @@ func Test_CORS_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package cors
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
|
@ -20,18 +22,20 @@ func matchSubdomain(domain, pattern string) bool {
|
|||
}
|
||||
domAuth := domain[didx+3:]
|
||||
// to avoid long loop by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
const maxDomainLen = 253
|
||||
if len(domAuth) > maxDomainLen {
|
||||
return false
|
||||
}
|
||||
patAuth := pattern[pidx+3:]
|
||||
|
||||
domComp := strings.Split(domAuth, ".")
|
||||
patComp := strings.Split(patAuth, ".")
|
||||
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
||||
const divHalf = 2
|
||||
for i := len(domComp)/divHalf - 1; i >= 0; i-- {
|
||||
opp := len(domComp) - 1 - i
|
||||
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
||||
}
|
||||
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
||||
for i := len(patComp)/divHalf - 1; i >= 0; i-- {
|
||||
opp := len(patComp) - 1 - i
|
||||
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -80,13 +80,13 @@ type Config struct {
|
|||
// Optional. Default: utils.UUID
|
||||
KeyGenerator func() string
|
||||
|
||||
// Deprecated, please use Expiration
|
||||
// Deprecated: Please use Expiration
|
||||
CookieExpires time.Duration
|
||||
|
||||
// Deprecated, please use Cookie* related fields
|
||||
// Deprecated: Please use Cookie* related fields
|
||||
Cookie *fiber.Cookie
|
||||
|
||||
// Deprecated, please use KeyLookup
|
||||
// Deprecated: Please use KeyLookup
|
||||
TokenLookup string
|
||||
|
||||
// ErrorHandler is executed when an error is returned from fiber.Handler.
|
||||
|
@ -105,6 +105,8 @@ type Config struct {
|
|||
const HeaderName = "X-Csrf-Token"
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
KeyLookup: "header:" + HeaderName,
|
||||
CookieName: "csrf_",
|
||||
|
@ -116,7 +118,7 @@ var ConfigDefault = Config{
|
|||
}
|
||||
|
||||
// default ErrorHandler that process return error from fiber.Handler
|
||||
var defaultErrorHandler = func(c *fiber.Ctx, err error) error {
|
||||
func defaultErrorHandler(_ *fiber.Ctx, _ error) error {
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
|
||||
|
@ -132,15 +134,15 @@ func configDefault(config ...Config) Config {
|
|||
|
||||
// Set default values
|
||||
if cfg.TokenLookup != "" {
|
||||
fmt.Println("[CSRF] TokenLookup is deprecated, please use KeyLookup")
|
||||
log.Printf("[CSRF] TokenLookup is deprecated, please use KeyLookup\n")
|
||||
cfg.KeyLookup = cfg.TokenLookup
|
||||
}
|
||||
if int(cfg.CookieExpires.Seconds()) > 0 {
|
||||
fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration")
|
||||
log.Printf("[CSRF] CookieExpires is deprecated, please use Expiration\n")
|
||||
cfg.Expiration = cfg.CookieExpires
|
||||
}
|
||||
if cfg.Cookie != nil {
|
||||
fmt.Println("[CSRF] Cookie is deprecated, please use Cookie* related fields")
|
||||
log.Printf("[CSRF] Cookie is deprecated, please use Cookie* related fields\n")
|
||||
if cfg.Cookie.Name != "" {
|
||||
cfg.CookieName = cfg.Cookie.Name
|
||||
}
|
||||
|
@ -178,7 +180,8 @@ func configDefault(config ...Config) Config {
|
|||
// Generate the correct extractor to get the token from the correct location
|
||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||
|
||||
if len(selectors) != 2 {
|
||||
const numParts = 2
|
||||
if len(selectors) != numParts {
|
||||
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
|
||||
}
|
||||
|
||||
|
|
|
@ -7,9 +7,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
errTokenNotFound = errors.New("csrf token not found")
|
||||
)
|
||||
var errTokenNotFound = errors.New("csrf token not found")
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
|
@ -22,7 +20,7 @@ func New(config ...Config) fiber.Handler {
|
|||
dummyValue := []byte{'+'}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
|
@ -39,7 +37,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// Assume that anything not defined as 'safe' by RFC7231 needs protection
|
||||
|
||||
// Extract token from client request i.e. header, query, param, form or cookie
|
||||
token, err = cfg.Extractor(c)
|
||||
token, err := cfg.Extractor(c)
|
||||
if err != nil {
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -23,7 +24,7 @@ func Test_CSRF(t *testing.T) {
|
|||
h := app.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
methods := [4]string{"GET", "HEAD", "OPTIONS", "TRACE"}
|
||||
methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
|
||||
|
||||
for _, method := range methods {
|
||||
// Generate CSRF token
|
||||
|
@ -33,14 +34,14 @@ func Test_CSRF(t *testing.T) {
|
|||
// Without CSRF cookie
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
||||
// Empty/invalid CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(HeaderName, "johndoe")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
@ -55,7 +56,7 @@ func Test_CSRF(t *testing.T) {
|
|||
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(HeaderName, token)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
@ -72,7 +73,7 @@ func Test_CSRF_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -92,7 +93,7 @@ func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
|
|||
|
||||
h := app.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
}
|
||||
|
||||
|
@ -110,7 +111,7 @@ func Test_CSRF_From_Form(t *testing.T) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Invalid CSRF token
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
@ -118,12 +119,12 @@ func Test_CSRF_From_Form(t *testing.T) {
|
|||
// Generate CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
||||
ctx.Request.SetBodyString("_csrf=" + token)
|
||||
h(ctx)
|
||||
|
@ -144,7 +145,7 @@ func Test_CSRF_From_Query(t *testing.T) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Invalid CSRF token
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
@ -152,7 +153,7 @@ func Test_CSRF_From_Query(t *testing.T) {
|
|||
// Generate CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.SetRequestURI("/")
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
|
@ -161,7 +162,7 @@ func Test_CSRF_From_Query(t *testing.T) {
|
|||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.SetRequestURI("/?_csrf=" + token)
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
||||
|
@ -181,7 +182,7 @@ func Test_CSRF_From_Param(t *testing.T) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Invalid CSRF token
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.SetRequestURI("/" + utils.UUID())
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
@ -189,7 +190,7 @@ func Test_CSRF_From_Param(t *testing.T) {
|
|||
// Generate CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.SetRequestURI("/" + utils.UUID())
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
|
@ -198,7 +199,7 @@ func Test_CSRF_From_Param(t *testing.T) {
|
|||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.SetRequestURI("/" + token)
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
||||
|
@ -218,7 +219,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Invalid CSRF token
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.SetRequestURI("/")
|
||||
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
|
||||
h(ctx)
|
||||
|
@ -227,7 +228,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
|
|||
// Generate CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.SetRequestURI("/")
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
|
@ -235,7 +236,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
|
|||
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
|
||||
ctx.Request.SetRequestURI("/")
|
||||
h(ctx)
|
||||
|
@ -268,7 +269,7 @@ func Test_CSRF_From_Custom(t *testing.T) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Invalid CSRF token
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
@ -276,12 +277,12 @@ func Test_CSRF_From_Custom(t *testing.T) {
|
|||
// Generate CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
||||
ctx.Request.SetBodyString("_csrf=" + token)
|
||||
h(ctx)
|
||||
|
@ -307,13 +308,13 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Generate CSRF token
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
|
||||
// invalid CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(HeaderName, "johndoe")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
||||
|
@ -339,69 +340,69 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Generate CSRF token
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
|
||||
// empty CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
|
||||
//func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
|
||||
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
|
||||
// t.Parallel()
|
||||
// app := fiber.New()
|
||||
//
|
||||
// app.Use(New())
|
||||
// app.Get("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
// app.Get("/test", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
// app.Post("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
//
|
||||
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
//
|
||||
// var token string
|
||||
// for _, c := range resp.Cookies() {
|
||||
// if c.Name != ConfigDefault.CookieName {
|
||||
// continue
|
||||
// }
|
||||
// token = c.Value
|
||||
// break
|
||||
// }
|
||||
//
|
||||
// fmt.Println("token", token)
|
||||
//
|
||||
// getReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// getReq.Header.Set(HeaderName, token)
|
||||
// resp, err = app.Test(getReq)
|
||||
//
|
||||
// getReq = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
|
||||
// getReq.Header.Set(HeaderName, token)
|
||||
//
|
||||
// resp, err = app.Test(getReq)
|
||||
//
|
||||
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
|
||||
// getReq.Header.Del(HeaderName)
|
||||
// resp, err = app.Test(getReq)
|
||||
//
|
||||
// postReq := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
// postReq.Header.Set(HeaderName, token)
|
||||
// resp, err = app.Test(postReq)
|
||||
//}
|
||||
// app := fiber.New()
|
||||
|
||||
// app.Use(New())
|
||||
// app.Get("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
// app.Get("/test", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
// app.Post("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
|
||||
// resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
// var token string
|
||||
// for _, c := range resp.Cookies() {
|
||||
// if c.Name != ConfigDefault.CookieName {
|
||||
// continue
|
||||
// }
|
||||
// token = c.Value
|
||||
// break
|
||||
// }
|
||||
|
||||
// fmt.Println("token", token)
|
||||
|
||||
// getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
// getReq.Header.Set(HeaderName, token)
|
||||
// resp, err = app.Test(getReq)
|
||||
|
||||
// getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
|
||||
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
|
||||
// getReq.Header.Set(HeaderName, token)
|
||||
|
||||
// resp, err = app.Test(getReq)
|
||||
|
||||
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
|
||||
// getReq.Header.Del(HeaderName)
|
||||
// resp, err = app.Test(getReq)
|
||||
|
||||
// postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
|
||||
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
// postReq.Header.Set(HeaderName, token)
|
||||
// resp, err = app.Test(postReq)
|
||||
// }
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
|
||||
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
|
||||
|
@ -417,12 +418,12 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Generate CSRF token
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
ctx.Request.Header.Set(HeaderName, token)
|
||||
|
||||
b.ReportAllocs()
|
||||
|
@ -449,7 +450,7 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
|
|||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Generate CSRF token
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
|
|
|
@ -41,74 +41,23 @@ func newManager(storage fiber.Storage) *manager {
|
|||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
// don't release item if we using memory storage
|
||||
if m.storage != nil {
|
||||
return
|
||||
}
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
it = m.acquire()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
func (m *manager) getRaw(key string) []byte {
|
||||
var raw []byte
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
} else {
|
||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
||||
m.memory.Set(utils.CopyString(key), it, exp)
|
||||
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error
|
||||
} else {
|
||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
||||
m.memory.Set(utils.CopyString(key), raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package encryptcookie
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
|
@ -32,6 +34,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Except: []string{"csrf_"},
|
||||
|
|
|
@ -2,6 +2,7 @@ package encryptcookie
|
|||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
|
|
@ -7,9 +7,11 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var testKey = GenerateKey()
|
||||
|
||||
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
||||
|
@ -35,14 +37,14 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
|||
|
||||
// Test empty cookie
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||
|
||||
// Test invalid cookie
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.SetCookie("test", "Invalid")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
@ -54,18 +56,19 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
|||
|
||||
// Test valid cookie
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
@ -91,7 +94,7 @@ func Test_Encrypt_Cookie_Next(t *testing.T) {
|
|||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
|
||||
}
|
||||
|
@ -123,7 +126,7 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
|
|||
h := app.Handler()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
|
@ -135,7 +138,8 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
|
|||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test2")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||
}
|
||||
|
||||
|
@ -169,18 +173,19 @@ func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
|
|||
h := app.Handler()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decodedBytes, _ := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
|
||||
decodedBytes, err := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", string(decodedBytes))
|
||||
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
|
|
@ -6,47 +6,56 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// EncryptCookie Encrypts a cookie value with specific encryption key
|
||||
func EncryptCookie(value, key string) (string, error) {
|
||||
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
|
||||
plaintext := []byte(value)
|
||||
keyDecoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to base64-decode key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(keyDecoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to read: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptCookie Decrypts a cookie value with specific encryption key
|
||||
func DecryptCookie(value, key string) (string, error) {
|
||||
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
|
||||
enc, _ := base64.StdEncoding.DecodeString(value)
|
||||
keyDecoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to base64-decode key: %w", err)
|
||||
}
|
||||
enc, err := base64.StdEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to base64-decode value: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(keyDecoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
|
@ -59,7 +68,7 @@ func DecryptCookie(value, key string) (string, error) {
|
|||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to decrypt ciphertext: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
|
@ -67,7 +76,8 @@ func DecryptCookie(value, key string) (string, error) {
|
|||
|
||||
// GenerateKey Generates an encryption key
|
||||
func GenerateKey() string {
|
||||
ret := make([]byte, 32)
|
||||
const keyLen = 32
|
||||
ret := make([]byte, keyLen)
|
||||
|
||||
if _, err := rand.Read(ret); err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -23,10 +23,8 @@ func (envVar *EnvVar) set(key, val string) {
|
|||
envVar.Vars[key] = val
|
||||
}
|
||||
|
||||
var defaultConfig = Config{}
|
||||
|
||||
func New(config ...Config) fiber.Handler {
|
||||
var cfg = defaultConfig
|
||||
var cfg Config
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
|
@ -57,8 +55,9 @@ func newEnvVar(cfg Config) *EnvVar {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
const numElems = 2
|
||||
for _, envVal := range os.Environ() {
|
||||
keyVal := strings.SplitN(envVal, "=", 2)
|
||||
keyVal := strings.SplitN(envVal, "=", numElems)
|
||||
if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists {
|
||||
vars.set(keyVal[0], keyVal[1])
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package envvar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -12,16 +14,25 @@ import (
|
|||
)
|
||||
|
||||
func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
|
||||
os.Setenv("testKey", "testEnvValue")
|
||||
os.Setenv("anotherEnvKey", "anotherEnvVal")
|
||||
os.Setenv("excludeKey", "excludeEnvValue")
|
||||
defer os.Unsetenv("testKey")
|
||||
defer os.Unsetenv("anotherEnvKey")
|
||||
defer os.Unsetenv("excludeKey")
|
||||
err := os.Setenv("testKey", "testEnvValue")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Setenv("anotherEnvKey", "anotherEnvVal")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Setenv("excludeKey", "excludeEnvValue")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv("testKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Unsetenv("anotherEnvKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Unsetenv("excludeKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
vars := newEnvVar(Config{
|
||||
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
|
||||
ExcludeVars: map[string]string{"excludeKey": ""}})
|
||||
ExcludeVars: map[string]string{"excludeKey": ""},
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue")
|
||||
utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal")
|
||||
|
@ -30,21 +41,28 @@ func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEnvVarHandler(t *testing.T) {
|
||||
os.Setenv("testKey", "testVal")
|
||||
defer os.Unsetenv("testKey")
|
||||
err := os.Setenv("testKey", "testVal")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv("testKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
expectedEnvVarResponse, _ := json.Marshal(
|
||||
expectedEnvVarResponse, err := json.Marshal(
|
||||
struct {
|
||||
Vars map[string]string `json:"vars"`
|
||||
}{
|
||||
map[string]string{"testKey": "testVal"},
|
||||
})
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New(Config{
|
||||
ExportVars: map[string]string{"testKey": ""}}))
|
||||
ExportVars: map[string]string{"testKey": ""},
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -57,14 +75,16 @@ func TestEnvVarHandler(t *testing.T) {
|
|||
func TestEnvVarHandlerNotMatched(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New(Config{
|
||||
ExportVars: map[string]string{"testKey": ""}}))
|
||||
ExportVars: map[string]string{"testKey": ""},
|
||||
}))
|
||||
|
||||
app.Get("/another-path", func(ctx *fiber.Ctx) error {
|
||||
utils.AssertEqual(t, nil, ctx.SendString("OK"))
|
||||
return nil
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://localhost/another-path", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -75,13 +95,18 @@ func TestEnvVarHandlerNotMatched(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEnvVarHandlerDefaultConfig(t *testing.T) {
|
||||
os.Setenv("testEnvKey", "testEnvVal")
|
||||
defer os.Unsetenv("testEnvKey")
|
||||
err := os.Setenv("testEnvKey", "testEnvVal")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv("testEnvKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New())
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -98,7 +123,8 @@ func TestEnvVarHandlerMethod(t *testing.T) {
|
|||
app := fiber.New()
|
||||
app.Use("/envvars", New())
|
||||
|
||||
req, _ := http.NewRequest("POST", "http://localhost/envvars", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
|
||||
|
@ -107,14 +133,19 @@ func TestEnvVarHandlerMethod(t *testing.T) {
|
|||
func TestEnvVarHandlerSpecialValue(t *testing.T) {
|
||||
testEnvKey := "testEnvKey"
|
||||
fakeBase64 := "testBase64:TQ=="
|
||||
os.Setenv(testEnvKey, fakeBase64)
|
||||
defer os.Unsetenv(testEnvKey)
|
||||
err := os.Setenv(testEnvKey, fakeBase64)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv(testEnvKey)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New())
|
||||
app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -126,7 +157,8 @@ func TestEnvVarHandlerSpecialValue(t *testing.T) {
|
|||
val := envVars.Vars[testEnvKey]
|
||||
utils.AssertEqual(t, fakeBase64, val)
|
||||
|
||||
req, _ = http.NewRequest("GET", "http://localhost/envvars/export", nil)
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
|
|
|
@ -5,12 +5,8 @@ import (
|
|||
"hash/crc32"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
)
|
||||
|
||||
var (
|
||||
normalizedHeaderETag = []byte("Etag")
|
||||
weakPrefix = []byte("W/")
|
||||
"github.com/valyala/bytebufferpool"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
|
@ -18,32 +14,38 @@ func New(config ...Config) fiber.Handler {
|
|||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
crc32q := crc32.MakeTable(0xD5828281)
|
||||
var (
|
||||
normalizedHeaderETag = []byte("Etag")
|
||||
weakPrefix = []byte("W/")
|
||||
)
|
||||
|
||||
const crcPol = 0xD5828281
|
||||
crc32q := crc32.MakeTable(crcPol)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Return err if next handler returns one
|
||||
if err = c.Next(); err != nil {
|
||||
return
|
||||
if err := c.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't generate ETags for invalid responses
|
||||
if c.Response().StatusCode() != fiber.StatusOK {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
body := c.Response().Body()
|
||||
// Skips ETag if no response body is present
|
||||
if len(body) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
// Skip ETag if header is already present
|
||||
if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate ETag for response
|
||||
|
@ -52,14 +54,14 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// Enable weak tag
|
||||
if cfg.Weak {
|
||||
_, _ = bb.Write(weakPrefix)
|
||||
_, _ = bb.Write(weakPrefix) //nolint:errcheck // This will never fail
|
||||
}
|
||||
|
||||
_ = bb.WriteByte('"')
|
||||
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
|
||||
bb.B = appendUint(bb.Bytes(), uint32(len(body)))
|
||||
_ = bb.WriteByte('-')
|
||||
_ = bb.WriteByte('-') //nolint:errcheck // This will never fail
|
||||
bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q))
|
||||
_ = bb.WriteByte('"')
|
||||
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
|
||||
|
||||
etag := bb.Bytes()
|
||||
|
||||
|
@ -78,7 +80,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// W/1 != W/2 || W/1 != 2
|
||||
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if bytes.Contains(clientEtag, etag) {
|
||||
|
@ -90,7 +92,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// 1 != 2
|
||||
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,7 +104,7 @@ func appendUint(dst []byte, n uint32) []byte {
|
|||
var q uint32
|
||||
for n >= 10 {
|
||||
i--
|
||||
q = n / 10
|
||||
q = n / 10 //nolint:gomnd // TODO: Explain why we divide by 10 here
|
||||
buf[i] = '0' + byte(n-q*10)
|
||||
n = q
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -21,7 +22,7 @@ func Test_ETag_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -37,7 +38,7 @@ func Test_ETag_SkipError(t *testing.T) {
|
|||
return fiber.ErrForbidden
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
@ -53,7 +54,7 @@ func Test_ETag_NotStatusOK(t *testing.T) {
|
|||
return c.SendStatus(fiber.StatusCreated)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode)
|
||||
}
|
||||
|
@ -69,7 +70,7 @@ func Test_ETag_NoBody(t *testing.T) {
|
|||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
@ -91,7 +92,7 @@ func Test_ETag_NewEtag(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
||||
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||
t.Helper()
|
||||
|
||||
app := fiber.New()
|
||||
|
@ -102,7 +103,7 @@ func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
|||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
if headerIfNoneMatch {
|
||||
etag := `"non-match"`
|
||||
if matched {
|
||||
|
@ -145,7 +146,7 @@ func Test_ETag_WeakEtag(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
||||
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||
t.Helper()
|
||||
|
||||
app := fiber.New()
|
||||
|
@ -156,7 +157,7 @@ func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
|||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
if headerIfNoneMatch {
|
||||
etag := `W/"non-match"`
|
||||
if matched {
|
||||
|
@ -199,7 +200,7 @@ func Test_ETag_CustomEtag(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
||||
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||
t.Helper()
|
||||
|
||||
app := fiber.New()
|
||||
|
@ -214,7 +215,7 @@ func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
|||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
if headerIfNoneMatch {
|
||||
etag := `"non-match"`
|
||||
if matched {
|
||||
|
@ -255,7 +256,7 @@ func Test_ETag_CustomEtagPut(t *testing.T) {
|
|||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("PUT", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodPut, "/", nil)
|
||||
req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -275,7 +276,7 @@ func Benchmark_Etag(b *testing.B) {
|
|||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ReportAllocs()
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package expvar
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
|
@ -10,6 +12,7 @@ type Config struct {
|
|||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp/expvarhandler"
|
||||
)
|
||||
|
||||
|
@ -29,6 +30,6 @@ func New(config ...Config) fiber.Handler {
|
|||
return nil
|
||||
}
|
||||
|
||||
return c.Redirect("/debug/vars", 302)
|
||||
return c.Redirect("/debug/vars", fiber.StatusFound)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
File: "",
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package favicon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Middleware_Favicon
|
||||
|
@ -25,22 +27,22 @@ func Test_Middleware_Favicon(t *testing.T) {
|
|||
})
|
||||
|
||||
// Skip Favicon middleware
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("OPTIONS", "/favicon.ico", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodOptions, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("PUT", "/favicon.ico", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPut, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "GET, HEAD, OPTIONS", resp.Header.Get(fiber.HeaderAllow))
|
||||
utils.AssertEqual(t, strings.Join([]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions}, ", "), resp.Header.Get(fiber.HeaderAllow))
|
||||
}
|
||||
|
||||
// go test -run Test_Middleware_Favicon_Not_Found
|
||||
|
@ -70,8 +72,7 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
|
|||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
|
@ -83,15 +84,15 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
|
|||
// TODO use os.Dir if fiber upgrades to 1.16
|
||||
type mockFS struct{}
|
||||
|
||||
func (m mockFS) Open(name string) (http.File, error) {
|
||||
func (mockFS) Open(name string) (http.File, error) {
|
||||
if name == "/" {
|
||||
name = "."
|
||||
} else {
|
||||
name = strings.TrimPrefix(name, "/")
|
||||
}
|
||||
file, err := os.Open(name)
|
||||
file, err := os.Open(name) //nolint:gosec // We're in a test func, so this is fine
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
@ -106,7 +107,7 @@ func Test_Middleware_Favicon_FileSystem(t *testing.T) {
|
|||
FileSystem: mockFS{},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
|
@ -123,7 +124,7 @@ func Test_Middleware_Favicon_CacheControl(t *testing.T) {
|
|||
File: "../../.github/testdata/favicon.ico",
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
|
@ -159,7 +160,7 @@ func Test_Favicon_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package filesystem
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
@ -55,6 +56,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Root: nil,
|
||||
|
@ -102,7 +105,7 @@ func New(config ...Config) fiber.Handler {
|
|||
cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
|
@ -131,28 +134,23 @@ func New(config ...Config) fiber.Handler {
|
|||
path = cfg.PathPrefix + path
|
||||
}
|
||||
|
||||
var (
|
||||
file http.File
|
||||
stat os.FileInfo
|
||||
)
|
||||
|
||||
if len(path) > 1 {
|
||||
path = utils.TrimRight(path, '/')
|
||||
}
|
||||
file, err = cfg.Root.Open(path)
|
||||
file, err := cfg.Root.Open(path)
|
||||
if err != nil && os.IsNotExist(err) && cfg.NotFoundFile != "" {
|
||||
file, err = cfg.Root.Open(cfg.NotFoundFile)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return c.Status(fiber.StatusNotFound).Next()
|
||||
}
|
||||
return
|
||||
return fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
|
||||
if stat, err = file.Stat(); err != nil {
|
||||
return
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat: %w", err)
|
||||
}
|
||||
|
||||
// Serve index if path is directory
|
||||
|
@ -200,7 +198,7 @@ func New(config ...Config) fiber.Handler {
|
|||
c.Response().SkipBody = true
|
||||
c.Response().Header.SetContentLength(contentLength)
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -210,22 +208,18 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// SendFile ...
|
||||
func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) {
|
||||
var (
|
||||
file http.File
|
||||
stat os.FileInfo
|
||||
)
|
||||
|
||||
file, err = fs.Open(path)
|
||||
func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) error {
|
||||
file, err := fs.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fiber.ErrNotFound
|
||||
}
|
||||
return err
|
||||
return fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
|
||||
if stat, err = file.Stat(); err != nil {
|
||||
return err
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat: %w", err)
|
||||
}
|
||||
|
||||
// Serve index if path is directory
|
||||
|
@ -268,7 +262,7 @@ func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) {
|
|||
c.Response().SkipBody = true
|
||||
c.Response().Header.SetContentLength(contentLength)
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -119,7 +121,7 @@ func Test_FileSystem(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp, err := app.Test(httptest.NewRequest("GET", tt.url, nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tt.url, nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
|
||||
|
||||
|
@ -142,7 +144,7 @@ func Test_FileSystem_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -168,7 +170,8 @@ func Test_FileSystem_Head(t *testing.T) {
|
|||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(fiber.MethodHead, "/test", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/test", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
@ -182,7 +185,8 @@ func Test_FileSystem_NoRoot(t *testing.T) {
|
|||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
_, _ = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func Test_FileSystem_UsingParam(t *testing.T) {
|
||||
|
@ -193,7 +197,8 @@ func Test_FileSystem_UsingParam(t *testing.T) {
|
|||
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(fiber.MethodHead, "/index", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/index", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
@ -207,7 +212,8 @@ func Test_FileSystem_UsingParam_NonFile(t *testing.T) {
|
|||
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(fiber.MethodHead, "/template", nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/template", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||
|
|
|
@ -13,18 +13,18 @@ import (
|
|||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func getFileExtension(path string) string {
|
||||
n := strings.LastIndexByte(path, '.')
|
||||
func getFileExtension(p string) string {
|
||||
n := strings.LastIndexByte(p, '.')
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
return path[n:]
|
||||
return p[n:]
|
||||
}
|
||||
|
||||
func dirList(c *fiber.Ctx, f http.File) error {
|
||||
fileinfos, err := f.Readdir(-1)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read dir: %w", err)
|
||||
}
|
||||
|
||||
fm := make(map[string]os.FileInfo, len(fileinfos))
|
||||
|
@ -36,13 +36,13 @@ func dirList(c *fiber.Ctx, f http.File) error {
|
|||
}
|
||||
|
||||
basePathEscaped := html.EscapeString(c.Path())
|
||||
fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
|
||||
fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
|
||||
fmt.Fprint(c, "<ul>")
|
||||
_, _ = fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
|
||||
_, _ = fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
|
||||
_, _ = fmt.Fprint(c, "<ul>")
|
||||
|
||||
if len(basePathEscaped) > 1 {
|
||||
parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..")
|
||||
fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
|
||||
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
|
@ -55,10 +55,10 @@ func dirList(c *fiber.Ctx, f http.File) error {
|
|||
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
|
||||
className = "file"
|
||||
}
|
||||
fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
|
||||
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
|
||||
pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime())
|
||||
}
|
||||
fmt.Fprint(c, "</ul></body></html>")
|
||||
_, _ = fmt.Fprint(c, "</ul></body></html>")
|
||||
|
||||
c.Type("html")
|
||||
|
||||
|
|
|
@ -9,9 +9,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
|
||||
)
|
||||
var ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
|
@ -51,13 +49,15 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
// Skip middleware if the request was done using a safe HTTP method
|
||||
return fiber.IsMethodSafe(c.Method())
|
||||
},
|
||||
|
||||
Lifetime: 30 * time.Minute,
|
||||
Lifetime: 30 * time.Minute, //nolint:gomnd // No magic number, just the default config
|
||||
|
||||
KeyHeader: "X-Idempotency-Key",
|
||||
KeyHeaderValidate: func(k string) error {
|
||||
|
@ -112,7 +112,7 @@ func configDefault(config ...Config) Config {
|
|||
|
||||
if cfg.Storage == nil {
|
||||
cfg.Storage = memory.New(memory.Config{
|
||||
GCInterval: cfg.Lifetime / 2,
|
||||
GCInterval: cfg.Lifetime / 2, //nolint:gomnd // Half the lifetime interval
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package idempotency_test
|
||||
|
||||
import (
|
||||
|
@ -14,6 +15,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -172,5 +174,4 @@ func Benchmark_Idempotency(b *testing.B) {
|
|||
h(c)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
@ -58,19 +58,21 @@ type Config struct {
|
|||
// Default: a new Fixed Window Rate Limiter
|
||||
LimiterMiddleware LimiterHandler
|
||||
|
||||
// DEPRECATED: Use Expiration instead
|
||||
// Deprecated: Use Expiration instead
|
||||
Duration time.Duration
|
||||
|
||||
// DEPRECATED, use Storage instead
|
||||
// Deprecated: Use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// DEPRECATED, use KeyGenerator instead
|
||||
// Deprecated: Use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Max: 5,
|
||||
Max: 5, //nolint:gomnd // No magic number, just the default config
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
|
@ -95,15 +97,15 @@ func configDefault(config ...Config) Config {
|
|||
|
||||
// Set default values
|
||||
if int(cfg.Duration.Seconds()) > 0 {
|
||||
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration")
|
||||
log.Printf("[LIMITER] Duration is deprecated, please use Expiration\n")
|
||||
cfg.Expiration = cfg.Duration
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
fmt.Println("[LIMITER] Key is deprecated, please us KeyGenerator")
|
||||
log.Printf("[LIMITER] Key is deprecated, please us KeyGenerator\n")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("[LIMITER] Store is deprecated, please use Storage")
|
||||
log.Printf("[LIMITER] Store is deprecated, please use Storage\n")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
|
|
|
@ -2,7 +2,6 @@ package limiter
|
|||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
@ -11,6 +10,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -34,7 +34,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
|
|||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
|
@ -50,13 +50,13 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
|
|||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
|
@ -96,13 +96,13 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
@ -120,21 +120,21 @@ func Test_Limiter_No_Skip_Choices(t *testing.T) {
|
|||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
if c.Params("status") == "fail" { //nolint:goconst // False positive
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
}
|
||||
|
@ -157,21 +157,21 @@ func Test_Limiter_Skip_Failed_Requests(t *testing.T) {
|
|||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
@ -196,21 +196,21 @@ func Test_Limiter_Skip_Successful_Requests(t *testing.T) {
|
|||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
@ -232,7 +232,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
|||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
@ -252,7 +252,7 @@ func Test_Limiter_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -271,7 +271,7 @@ func Test_Limiter_Headers(t *testing.T) {
|
|||
})
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
app.Handler()(fctx)
|
||||
|
@ -301,7 +301,7 @@ func Benchmark_Limiter(b *testing.B) {
|
|||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
@ -327,7 +327,7 @@ func Test_Sliding_Window(t *testing.T) {
|
|||
})
|
||||
|
||||
singleRequest := func(shouldFail bool) {
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
if shouldFail {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
|
|
@ -46,7 +46,7 @@ func newManager(storage fiber.Storage) *manager {
|
|||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
|
@ -58,37 +58,33 @@ func (m *manager) release(e *item) {
|
|||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
func (m *manager) get(key string) *item {
|
||||
var it *item
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
raw, err := m.storage.Get(key)
|
||||
if err != nil {
|
||||
return it
|
||||
}
|
||||
if raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
return it
|
||||
}
|
||||
}
|
||||
return
|
||||
return it
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||
it = m.acquire()
|
||||
return it
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
return it
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
|
@ -96,21 +92,3 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
|
|||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -95,7 +95,7 @@ app.Use(logger.New(logger.Config{
|
|||
TimeZone: "Asia/Shanghai",
|
||||
Done: func(c *fiber.Ctx, logString []byte) {
|
||||
if c.Response().StatusCode() != fiber.StatusOK {
|
||||
reporter.SendToSlack(logString)
|
||||
reporter.SendToSlack(logString)
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
@ -189,7 +189,7 @@ const (
|
|||
TagBytesReceived = "bytesReceived"
|
||||
TagRoute = "route"
|
||||
TagError = "error"
|
||||
// DEPRECATED: Use TagReqHeader instead
|
||||
// Deprecated: Use TagReqHeader instead
|
||||
TagHeader = "header:" // request header
|
||||
TagReqHeader = "reqHeader:" // request header
|
||||
TagRespHeader = "respHeader:" // response header
|
||||
|
|
|
@ -79,13 +79,15 @@ type Buffer interface {
|
|||
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Done: nil,
|
||||
Format: "[${time}] ${status} - ${latency} ${method} ${path}\n",
|
||||
TimeFormat: "15:04:05",
|
||||
TimeZone: "Local",
|
||||
TimeInterval: 500 * time.Millisecond,
|
||||
TimeInterval: 500 * time.Millisecond, //nolint:gomnd // No magic number, just the default config
|
||||
Output: os.Stdout,
|
||||
enableColors: true,
|
||||
}
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DataPool = sync.Pool{New: func() interface{} { return new(Data) }}
|
||||
|
||||
// Data is a struct to define some variables to use in custom logger function.
|
||||
type Data struct {
|
||||
Pid string
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/mattn/go-colorable"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
|
@ -55,6 +56,8 @@ func New(config ...Config) fiber.Handler {
|
|||
once sync.Once
|
||||
mu sync.Mutex
|
||||
errHandler fiber.ErrorHandler
|
||||
|
||||
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
|
||||
)
|
||||
|
||||
// If colors are enabled, check terminal compatibility
|
||||
|
@ -75,7 +78,7 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
|
@ -101,13 +104,13 @@ func New(config ...Config) fiber.Handler {
|
|||
})
|
||||
|
||||
// Logger data
|
||||
data := DataPool.Get().(*Data)
|
||||
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||
// no need for a reset, as long as we always override everything
|
||||
data.Pid = pid
|
||||
data.ErrPaddingStr = errPaddingStr
|
||||
data.Timestamp = timestamp
|
||||
// put data back in the pool
|
||||
defer DataPool.Put(data)
|
||||
defer dataPool.Put(data)
|
||||
|
||||
// Set latency start time
|
||||
if cfg.enableLatency {
|
||||
|
@ -121,7 +124,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// Manually call error handler
|
||||
if chainErr != nil {
|
||||
if err := errHandler(c, chainErr); err != nil {
|
||||
_ = c.SendStatus(fiber.StatusInternalServerError)
|
||||
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,18 +145,20 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Format log to buffer
|
||||
_, _ = buf.WriteString(fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n",
|
||||
timestamp.Load().(string),
|
||||
statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset,
|
||||
data.Stop.Sub(data.Start).Round(time.Millisecond),
|
||||
c.IP(),
|
||||
methodColor(c.Method(), colors), c.Method(), colors.Reset,
|
||||
c.Path(),
|
||||
formatErr,
|
||||
))
|
||||
_, _ = buf.WriteString( //nolint:errcheck // This will never fail
|
||||
fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n",
|
||||
timestamp.Load().(string),
|
||||
statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset,
|
||||
data.Stop.Sub(data.Start).Round(time.Millisecond),
|
||||
c.IP(),
|
||||
methodColor(c.Method(), colors), c.Method(), colors.Reset,
|
||||
c.Path(),
|
||||
formatErr,
|
||||
),
|
||||
)
|
||||
|
||||
// Write buffer to output
|
||||
_, _ = cfg.Output.Write(buf.Bytes())
|
||||
_, _ = cfg.Output.Write(buf.Bytes()) //nolint:errcheck // This will never fail
|
||||
|
||||
if cfg.Done != nil {
|
||||
cfg.Done(c, buf.Bytes())
|
||||
|
@ -169,7 +174,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
|
||||
for i, logFunc := range logFunChain {
|
||||
if logFunc == nil {
|
||||
_, _ = buf.Write(templateChain[i])
|
||||
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
|
||||
} else if templateChain[i] == nil {
|
||||
_, err = logFunc(buf, c, data, "")
|
||||
} else {
|
||||
|
@ -182,7 +187,7 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// Also write errors to the buffer
|
||||
if err != nil {
|
||||
_, _ = buf.WriteString(err.Error())
|
||||
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
|
||||
}
|
||||
mu.Lock()
|
||||
// Write buffer to output
|
||||
|
@ -190,7 +195,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// Write error to output
|
||||
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
|
||||
// There is something wrong with the given io.Writer
|
||||
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package logger
|
||||
|
||||
import (
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
@ -37,7 +39,7 @@ func Test_Logger(t *testing.T) {
|
|||
return errors.New("some random error")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
utils.AssertEqual(t, "some random error", buf.String())
|
||||
|
@ -70,21 +72,21 @@ func Test_Logger_locals(t *testing.T) {
|
|||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "johndoe", buf.String())
|
||||
|
||||
buf.Reset()
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/int", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/int", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "55", buf.String())
|
||||
|
||||
buf.Reset()
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/empty", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "", buf.String())
|
||||
|
@ -100,7 +102,7 @@ func Test_Logger_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -113,15 +115,15 @@ func Test_Logger_Done(t *testing.T) {
|
|||
app.Use(New(Config{
|
||||
Done: func(c *fiber.Ctx, logString []byte) {
|
||||
if c.Response().StatusCode() == fiber.StatusOK {
|
||||
buf.Write(logString)
|
||||
_, err := buf.Write(logString)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
},
|
||||
})).Get("/logging", func(ctx *fiber.Ctx) error {
|
||||
return ctx.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/logging", nil))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, true, buf.Len() > 0)
|
||||
|
@ -135,7 +137,7 @@ func Test_Logger_ErrorTimeZone(t *testing.T) {
|
|||
TimeZone: "invalid",
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -156,7 +158,7 @@ func Test_Logger_ErrorOutput(t *testing.T) {
|
|||
Output: o,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
|
@ -178,7 +180,7 @@ func Test_Logger_All(t *testing.T) {
|
|||
// Alias colors
|
||||
colors := app.Config().ColorScheme
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
|
@ -198,7 +200,7 @@ func Test_Query_Params(t *testing.T) {
|
|||
Output: buf,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar&baz=moz", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
|
@ -226,7 +228,7 @@ func Test_Response_Body(t *testing.T) {
|
|||
return c.Send([]byte("Post in test"))
|
||||
})
|
||||
|
||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
expectedGetResponse := "Sample response body"
|
||||
|
@ -234,7 +236,7 @@ func Test_Response_Body(t *testing.T) {
|
|||
|
||||
buf.Reset() // Reset buffer to test POST
|
||||
|
||||
_, err = app.Test(httptest.NewRequest("POST", "/test", nil))
|
||||
_, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
expectedPostResponse := "Post in test"
|
||||
|
@ -258,7 +260,7 @@ func Test_Logger_AppendUint(t *testing.T) {
|
|||
return c.SendString("hello")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "0 5 200", buf.String())
|
||||
|
@ -285,12 +287,11 @@ func Test_Logger_Data_Race(t *testing.T) {
|
|||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
resp1, err1 = app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
wg.Done()
|
||||
}()
|
||||
resp2, err2 = app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
wg.Wait()
|
||||
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode)
|
||||
utils.AssertEqual(t, nil, err2)
|
||||
|
@ -299,21 +300,23 @@ func Test_Logger_Data_Race(t *testing.T) {
|
|||
|
||||
// go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4
|
||||
func Benchmark_Logger(b *testing.B) {
|
||||
benchSetup := func(bb *testing.B, app *fiber.App) {
|
||||
benchSetup := func(b *testing.B, app *fiber.App) {
|
||||
b.Helper()
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
bb.ReportAllocs()
|
||||
bb.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < bb.N; n++ {
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(bb, 200, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
|
||||
}
|
||||
|
||||
b.Run("Base", func(bb *testing.B) {
|
||||
|
@ -375,8 +378,7 @@ func Test_Response_Header(t *testing.T) {
|
|||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||
|
@ -396,10 +398,10 @@ func Test_Req_Header(t *testing.T) {
|
|||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
headerReq := httptest.NewRequest("GET", "/", nil)
|
||||
headerReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
headerReq.Header.Add("test", "Hello fiber!")
|
||||
resp, err := app.Test(headerReq)
|
||||
|
||||
resp, err := app.Test(headerReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||
|
@ -419,10 +421,10 @@ func Test_ReqHeader_Header(t *testing.T) {
|
|||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
reqHeaderReq := httptest.NewRequest("GET", "/", nil)
|
||||
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
||||
resp, err := app.Test(reqHeaderReq)
|
||||
|
||||
resp, err := app.Test(reqHeaderReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||
|
@ -449,10 +451,10 @@ func Test_CustomTags(t *testing.T) {
|
|||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
reqHeaderReq := httptest.NewRequest("GET", "/", nil)
|
||||
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
||||
resp, err := app.Test(reqHeaderReq)
|
||||
|
||||
resp, err := app.Test(reqHeaderReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, customTag, buf.String())
|
||||
|
@ -492,7 +494,7 @@ func Test_Logger_ByteSent_Streaming(t *testing.T) {
|
|||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "0 0 200", buf.String())
|
||||
|
|
|
@ -31,7 +31,7 @@ const (
|
|||
TagBytesReceived = "bytesReceived"
|
||||
TagRoute = "route"
|
||||
TagError = "error"
|
||||
// DEPRECATED: Use TagReqHeader instead
|
||||
// Deprecated: Use TagReqHeader instead
|
||||
TagHeader = "header:"
|
||||
TagReqHeader = "reqHeader:"
|
||||
TagRespHeader = "respHeader:"
|
||||
|
@ -195,7 +195,7 @@ func createTagMap(cfg *Config) map[string]LogFunc {
|
|||
return output.WriteString(fmt.Sprintf("%7v", latency))
|
||||
},
|
||||
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(data.Timestamp.Load().(string))
|
||||
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
|
||||
},
|
||||
}
|
||||
// merge with custom tags from user
|
||||
|
|
|
@ -14,13 +14,16 @@ import (
|
|||
// funcChain contains for the parts which exist the functions for the dynamic parts
|
||||
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
|
||||
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
|
||||
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [][]byte, funcChain []LogFunc, err error) {
|
||||
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
|
||||
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
|
||||
templateB := utils.UnsafeBytes(cfg.Format)
|
||||
startTagB := utils.UnsafeBytes(startTag)
|
||||
endTagB := utils.UnsafeBytes(endTag)
|
||||
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
|
||||
|
||||
var fixParts [][]byte
|
||||
var funcChain []LogFunc
|
||||
|
||||
for {
|
||||
currentPos := bytes.Index(templateB, startTagB)
|
||||
if currentPos < 0 {
|
||||
|
@ -42,13 +45,13 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
|
|||
// ## function block ##
|
||||
// first check for tags with parameters
|
||||
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
|
||||
if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]; ok {
|
||||
funcChain = append(funcChain, logFunc)
|
||||
// add param to the fixParts
|
||||
fixParts = append(fixParts, templateB[index+1:currentPos])
|
||||
} else {
|
||||
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
|
||||
if !ok {
|
||||
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
|
||||
}
|
||||
funcChain = append(funcChain, logFunc)
|
||||
// add param to the fixParts
|
||||
fixParts = append(fixParts, templateB[index+1:currentPos])
|
||||
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
|
||||
// add functions without parameter
|
||||
funcChain = append(funcChain, logFunc)
|
||||
|
@ -63,5 +66,5 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
|
|||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, templateB)
|
||||
|
||||
return
|
||||
return fixParts, funcChain, nil
|
||||
}
|
||||
|
|
|
@ -41,36 +41,48 @@ type Config struct {
|
|||
// ChartJsURL for specify ChartJS library path or URL . also you can use relative path
|
||||
//
|
||||
// Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js
|
||||
ChartJsURL string
|
||||
ChartJSURL string
|
||||
|
||||
index string
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Title: defaultTitle,
|
||||
Refresh: defaultRefresh,
|
||||
FontURL: defaultFontURL,
|
||||
ChartJsURL: defaultChartJsURL,
|
||||
ChartJSURL: defaultChartJSURL,
|
||||
CustomHead: defaultCustomHead,
|
||||
APIOnly: false,
|
||||
Next: nil,
|
||||
index: newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL,
|
||||
defaultCustomHead}),
|
||||
index: newIndex(viewBag{
|
||||
defaultTitle,
|
||||
defaultRefresh,
|
||||
defaultFontURL,
|
||||
defaultChartJSURL,
|
||||
defaultCustomHead,
|
||||
}),
|
||||
}
|
||||
|
||||
func configDefault(config ...Config) Config {
|
||||
// Users can change ConfigDefault.Title/Refresh which then
|
||||
// become incompatible with ConfigDefault.index
|
||||
if ConfigDefault.Title != defaultTitle || ConfigDefault.Refresh != defaultRefresh ||
|
||||
ConfigDefault.FontURL != defaultFontURL || ConfigDefault.ChartJsURL != defaultChartJsURL ||
|
||||
if ConfigDefault.Title != defaultTitle ||
|
||||
ConfigDefault.Refresh != defaultRefresh ||
|
||||
ConfigDefault.FontURL != defaultFontURL ||
|
||||
ConfigDefault.ChartJSURL != defaultChartJSURL ||
|
||||
ConfigDefault.CustomHead != defaultCustomHead {
|
||||
|
||||
if ConfigDefault.Refresh < minRefresh {
|
||||
ConfigDefault.Refresh = minRefresh
|
||||
}
|
||||
// update default index with new default title/refresh
|
||||
ConfigDefault.index = newIndex(viewBag{ConfigDefault.Title,
|
||||
ConfigDefault.Refresh, ConfigDefault.FontURL, ConfigDefault.ChartJsURL, ConfigDefault.CustomHead})
|
||||
ConfigDefault.index = newIndex(viewBag{
|
||||
ConfigDefault.Title,
|
||||
ConfigDefault.Refresh,
|
||||
ConfigDefault.FontURL,
|
||||
ConfigDefault.ChartJSURL,
|
||||
ConfigDefault.CustomHead,
|
||||
})
|
||||
}
|
||||
|
||||
// Return default config if nothing provided
|
||||
|
@ -93,8 +105,8 @@ func configDefault(config ...Config) Config {
|
|||
cfg.FontURL = defaultFontURL
|
||||
}
|
||||
|
||||
if cfg.ChartJsURL == "" {
|
||||
cfg.ChartJsURL = defaultChartJsURL
|
||||
if cfg.ChartJSURL == "" {
|
||||
cfg.ChartJSURL = defaultChartJSURL
|
||||
}
|
||||
if cfg.Refresh < minRefresh {
|
||||
cfg.Refresh = minRefresh
|
||||
|
@ -112,8 +124,8 @@ func configDefault(config ...Config) Config {
|
|||
cfg.index = newIndex(viewBag{
|
||||
title: cfg.Title,
|
||||
refresh: cfg.Refresh,
|
||||
fontUrl: cfg.FontURL,
|
||||
chartJsUrl: cfg.ChartJsURL,
|
||||
fontURL: cfg.FontURL,
|
||||
chartJSURL: cfg.ChartJSURL,
|
||||
customHead: cfg.CustomHead,
|
||||
})
|
||||
|
||||
|
|
|
@ -18,11 +18,11 @@ func Test_Config_Default(t *testing.T) {
|
|||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set title", func(t *testing.T) {
|
||||
|
@ -35,11 +35,11 @@ func Test_Config_Default(t *testing.T) {
|
|||
utils.AssertEqual(t, title, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set refresh less than default", func(t *testing.T) {
|
||||
|
@ -51,11 +51,11 @@ func Test_Config_Default(t *testing.T) {
|
|||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, minRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set refresh", func(t *testing.T) {
|
||||
|
@ -68,45 +68,45 @@ func Test_Config_Default(t *testing.T) {
|
|||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, refresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set font url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fontUrl := "https://example.com"
|
||||
fontURL := "https://example.com"
|
||||
cfg := configDefault(Config{
|
||||
FontURL: fontUrl,
|
||||
FontURL: fontURL,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, fontUrl, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, fontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontUrl, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set chart js url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chartUrl := "http://example.com"
|
||||
chartURL := "http://example.com"
|
||||
cfg := configDefault(Config{
|
||||
ChartJsURL: chartUrl,
|
||||
ChartJSURL: chartURL,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, chartUrl, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, chartURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartUrl, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set custom head", func(t *testing.T) {
|
||||
|
@ -119,11 +119,11 @@ func Test_Config_Default(t *testing.T) {
|
|||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, head, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, head}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, head}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set api only", func(t *testing.T) {
|
||||
|
@ -135,11 +135,11 @@ func Test_Config_Default(t *testing.T) {
|
|||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, true, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set next", func(t *testing.T) {
|
||||
|
@ -154,10 +154,10 @@ func Test_Config_Default(t *testing.T) {
|
|||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, f(nil), cfg.Next(nil))
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -9,23 +9,22 @@ import (
|
|||
type viewBag struct {
|
||||
title string
|
||||
refresh time.Duration
|
||||
fontUrl string
|
||||
chartJsUrl string
|
||||
fontURL string
|
||||
chartJSURL string
|
||||
customHead string
|
||||
}
|
||||
|
||||
// returns index with new title/refresh
|
||||
func newIndex(dat viewBag) string {
|
||||
|
||||
timeout := dat.refresh.Milliseconds() - timeoutDiff
|
||||
if timeout < timeoutDiff {
|
||||
timeout = timeoutDiff
|
||||
}
|
||||
ts := strconv.FormatInt(timeout, 10)
|
||||
replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts,
|
||||
"$FONT_URL", dat.fontUrl, "$CHART_JS_URL", dat.chartJsUrl, "$CUSTOM_HEAD", dat.customHead,
|
||||
"$FONT_URL", dat.fontURL, "$CHART_JS_URL", dat.chartJSURL, "$CUSTOM_HEAD", dat.customHead,
|
||||
)
|
||||
return replacer.Replace(indexHtml)
|
||||
return replacer.Replace(indexHTML)
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -35,11 +34,11 @@ const (
|
|||
timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff
|
||||
minRefresh = timeoutDiff * time.Millisecond
|
||||
defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap`
|
||||
defaultChartJsURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
|
||||
defaultChartJSURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
|
||||
defaultCustomHead = ``
|
||||
|
||||
// parametrized by $TITLE and $TIMEOUT
|
||||
indexHtml = `<!DOCTYPE html>
|
||||
indexHTML = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
|
|
|
@ -33,18 +33,20 @@ type statsOS struct {
|
|||
Conns int `json:"conns"`
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var (
|
||||
monitPidCpu atomic.Value
|
||||
monitPidRam atomic.Value
|
||||
monitPidConns atomic.Value
|
||||
monitPIDCPU atomic.Value
|
||||
monitPIDRAM atomic.Value
|
||||
monitPIDConns atomic.Value
|
||||
|
||||
monitOsCpu atomic.Value
|
||||
monitOsRam atomic.Value
|
||||
monitOsTotalRam atomic.Value
|
||||
monitOsLoadAvg atomic.Value
|
||||
monitOsConns atomic.Value
|
||||
monitOSCPU atomic.Value
|
||||
monitOSRAM atomic.Value
|
||||
monitOSTotalRAM atomic.Value
|
||||
monitOSLoadAvg atomic.Value
|
||||
monitOSConns atomic.Value
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var (
|
||||
mutex sync.RWMutex
|
||||
once sync.Once
|
||||
|
@ -58,7 +60,7 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// Start routine to update statistics
|
||||
once.Do(func() {
|
||||
p, _ := process.NewProcess(int32(os.Getpid()))
|
||||
p, _ := process.NewProcess(int32(os.Getpid())) //nolint:errcheck // TODO: Handle error
|
||||
|
||||
updateStatistics(p)
|
||||
|
||||
|
@ -72,6 +74,7 @@ func New(config ...Config) fiber.Handler {
|
|||
})
|
||||
|
||||
// Return new handler
|
||||
//nolint:errcheck // Ignore the type-assertion errors
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
|
@ -83,15 +86,15 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly {
|
||||
mutex.Lock()
|
||||
data.PID.CPU = monitPidCpu.Load().(float64)
|
||||
data.PID.RAM = monitPidRam.Load().(uint64)
|
||||
data.PID.Conns = monitPidConns.Load().(int)
|
||||
data.PID.CPU, _ = monitPIDCPU.Load().(float64)
|
||||
data.PID.RAM, _ = monitPIDRAM.Load().(uint64)
|
||||
data.PID.Conns, _ = monitPIDConns.Load().(int)
|
||||
|
||||
data.OS.CPU = monitOsCpu.Load().(float64)
|
||||
data.OS.RAM = monitOsRam.Load().(uint64)
|
||||
data.OS.TotalRAM = monitOsTotalRam.Load().(uint64)
|
||||
data.OS.LoadAvg = monitOsLoadAvg.Load().(float64)
|
||||
data.OS.Conns = monitOsConns.Load().(int)
|
||||
data.OS.CPU, _ = monitOSCPU.Load().(float64)
|
||||
data.OS.RAM, _ = monitOSRAM.Load().(uint64)
|
||||
data.OS.TotalRAM, _ = monitOSTotalRAM.Load().(uint64)
|
||||
data.OS.LoadAvg, _ = monitOSLoadAvg.Load().(float64)
|
||||
data.OS.Conns, _ = monitOSConns.Load().(int)
|
||||
mutex.Unlock()
|
||||
return c.Status(fiber.StatusOK).JSON(data)
|
||||
}
|
||||
|
@ -101,29 +104,35 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
func updateStatistics(p *process.Process) {
|
||||
pidCpu, _ := p.CPUPercent()
|
||||
monitPidCpu.Store(pidCpu / 10)
|
||||
|
||||
if osCpu, _ := cpu.Percent(0, false); len(osCpu) > 0 {
|
||||
monitOsCpu.Store(osCpu[0])
|
||||
pidCPU, err := p.CPUPercent()
|
||||
if err != nil {
|
||||
monitPIDCPU.Store(pidCPU / 10) //nolint:gomnd // TODO: Explain why we divide by 10 here
|
||||
}
|
||||
|
||||
if pidMem, _ := p.MemoryInfo(); pidMem != nil {
|
||||
monitPidRam.Store(pidMem.RSS)
|
||||
if osCPU, err := cpu.Percent(0, false); err != nil && len(osCPU) > 0 {
|
||||
monitOSCPU.Store(osCPU[0])
|
||||
}
|
||||
|
||||
if osMem, _ := mem.VirtualMemory(); osMem != nil {
|
||||
monitOsRam.Store(osMem.Used)
|
||||
monitOsTotalRam.Store(osMem.Total)
|
||||
if pidRAM, err := p.MemoryInfo(); err != nil && pidRAM != nil {
|
||||
monitPIDRAM.Store(pidRAM.RSS)
|
||||
}
|
||||
|
||||
if loadAvg, _ := load.Avg(); loadAvg != nil {
|
||||
monitOsLoadAvg.Store(loadAvg.Load1)
|
||||
if osRAM, err := mem.VirtualMemory(); err != nil && osRAM != nil {
|
||||
monitOSRAM.Store(osRAM.Used)
|
||||
monitOSTotalRAM.Store(osRAM.Total)
|
||||
}
|
||||
|
||||
pidConns, _ := net.ConnectionsPid("tcp", p.Pid)
|
||||
monitPidConns.Store(len(pidConns))
|
||||
if loadAvg, err := load.Avg(); err != nil && loadAvg != nil {
|
||||
monitOSLoadAvg.Store(loadAvg.Load1)
|
||||
}
|
||||
|
||||
osConns, _ := net.Connections("tcp")
|
||||
monitOsConns.Store(len(osConns))
|
||||
pidConns, err := net.ConnectionsPid("tcp", p.Pid)
|
||||
if err != nil {
|
||||
monitPIDConns.Store(len(pidConns))
|
||||
}
|
||||
|
||||
osConns, err := net.Connections("tcp")
|
||||
if err != nil {
|
||||
monitOSConns.Store(len(osConns))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -61,6 +62,7 @@ func Test_Monitor_Html(t *testing.T) {
|
|||
conf.Refresh.Milliseconds()-timeoutDiff)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||
}
|
||||
|
||||
func Test_Monitor_Html_CustomCodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -82,8 +84,10 @@ func Test_Monitor_Html_CustomCodes(t *testing.T) {
|
|||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||
|
||||
// custom config
|
||||
conf := Config{Title: "New " + defaultTitle, Refresh: defaultRefresh + time.Second,
|
||||
ChartJsURL: "https://cdnjs.com/libraries/Chart.js",
|
||||
conf := Config{
|
||||
Title: "New " + defaultTitle,
|
||||
Refresh: defaultRefresh + time.Second,
|
||||
ChartJSURL: "https://cdnjs.com/libraries/Chart.js",
|
||||
FontURL: "/public/my-font.css",
|
||||
CustomHead: `<style>body{background:#fff}</style>`,
|
||||
}
|
||||
|
@ -136,7 +140,7 @@ func Benchmark_Monitor(b *testing.B) {
|
|||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package pprof
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
|
@ -17,6 +19,7 @@ type Config struct {
|
|||
Prefix string
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
}
|
||||
|
|
|
@ -5,22 +5,8 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/valyala/fasthttp/fasthttpadaptor"
|
||||
)
|
||||
|
||||
// Set pprof adaptors
|
||||
var (
|
||||
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
|
||||
pprofCmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
|
||||
pprofProfile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
|
||||
pprofSymbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
|
||||
pprofTrace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
|
||||
pprofAllocs = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP)
|
||||
pprofBlock = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP)
|
||||
pprofGoroutine = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP)
|
||||
pprofHeap = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP)
|
||||
pprofMutex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP)
|
||||
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
|
||||
"github.com/valyala/fasthttp/fasthttpadaptor"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
|
@ -28,6 +14,21 @@ func New(config ...Config) fiber.Handler {
|
|||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Set pprof adaptors
|
||||
var (
|
||||
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
|
||||
pprofCmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
|
||||
pprofProfile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
|
||||
pprofSymbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
|
||||
pprofTrace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
|
||||
pprofAllocs = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP)
|
||||
pprofBlock = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP)
|
||||
pprofGoroutine = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP)
|
||||
pprofHeap = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP)
|
||||
pprofMutex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP)
|
||||
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
|
||||
)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -48,7 +49,7 @@ type Config struct {
|
|||
WriteBufferSize int
|
||||
|
||||
// tls config for the http client.
|
||||
TlsConfig *tls.Config
|
||||
TlsConfig *tls.Config //nolint:stylecheck,revive // TODO: Rename to "TLSConfig" in v3
|
||||
|
||||
// Client is custom client when client config is complex.
|
||||
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
||||
|
@ -57,6 +58,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
ModifyRequest: nil,
|
||||
|
|
|
@ -3,19 +3,20 @@ package proxy
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New is deprecated
|
||||
func New(config Config) fiber.Handler {
|
||||
fmt.Println("proxy.New is deprecated, please use proxy.Balancer instead")
|
||||
log.Printf("proxy.New is deprecated, please use proxy.Balancer instead\n")
|
||||
return Balancer(config)
|
||||
}
|
||||
|
||||
|
@ -25,7 +26,7 @@ func Balancer(config Config) fiber.Handler {
|
|||
cfg := configDefault(config)
|
||||
|
||||
// Load balanced client
|
||||
var lbc = &fasthttp.LBClient{}
|
||||
lbc := &fasthttp.LBClient{}
|
||||
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
||||
// will not be used if the client are set.
|
||||
if config.Client == nil {
|
||||
|
@ -61,7 +62,7 @@ func Balancer(config Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
|
@ -76,7 +77,7 @@ func Balancer(config Config) fiber.Handler {
|
|||
|
||||
// Modify request
|
||||
if cfg.ModifyRequest != nil {
|
||||
if err = cfg.ModifyRequest(c); err != nil {
|
||||
if err := cfg.ModifyRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -84,7 +85,7 @@ func Balancer(config Config) fiber.Handler {
|
|||
req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
|
||||
|
||||
// Forward request
|
||||
if err = lbc.Do(req, res); err != nil {
|
||||
if err := lbc.Do(req, res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -93,7 +94,7 @@ func Balancer(config Config) fiber.Handler {
|
|||
|
||||
// Modify response
|
||||
if cfg.ModifyResponse != nil {
|
||||
if err = cfg.ModifyResponse(c); err != nil {
|
||||
if err := cfg.ModifyResponse(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -103,16 +104,20 @@ func Balancer(config Config) fiber.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var client = &fasthttp.Client{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var lock sync.RWMutex
|
||||
|
||||
// WithTlsConfig update http client with a user specified tls.config
|
||||
// This function should be called before Do and Forward.
|
||||
// Deprecated: use WithClient instead.
|
||||
//
|
||||
//nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3
|
||||
func WithTlsConfig(tlsConfig *tls.Config) {
|
||||
client.TLSConfig = tlsConfig
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -13,10 +12,11 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/tlstest"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) {
|
||||
func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, string) {
|
||||
t.Helper()
|
||||
|
||||
target := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
@ -60,7 +60,7 @@ func Test_Proxy_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -69,11 +69,11 @@ func Test_Proxy_Next(t *testing.T) {
|
|||
func Test_Proxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
target, addr := createProxyTestServer(
|
||||
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
|
||||
)
|
||||
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
|
||||
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
|
||||
|
@ -81,7 +81,7 @@ func Test_Proxy(t *testing.T) {
|
|||
|
||||
app.Use(Balancer(Config{Servers: []string{addr}}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Host = addr
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -107,7 +107,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
|
|||
})
|
||||
|
||||
addr := ln.Addr().String()
|
||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true}
|
||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
|
||||
|
||||
// disable certificate verification in Balancer
|
||||
app.Use(Balancer(Config{
|
||||
|
@ -128,9 +128,9 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
|
|||
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, targetAddr := createProxyTestServer(func(c *fiber.Ctx) error {
|
||||
_, targetAddr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("hello from target")
|
||||
}, t)
|
||||
})
|
||||
|
||||
proxyServerTLSConf, _, err := tlstest.GetTLSConfigs()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -164,13 +164,13 @@ func Test_Proxy_Forward(t *testing.T) {
|
|||
|
||||
app := fiber.New()
|
||||
|
||||
_, addr := createProxyTestServer(
|
||||
func(c *fiber.Ctx) error { return c.SendString("forwarded") }, t,
|
||||
)
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("forwarded")
|
||||
})
|
||||
|
||||
app.Use(Forward("http://" + addr))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
|
@ -198,7 +198,7 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
|
|||
})
|
||||
|
||||
addr := ln.Addr().String()
|
||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true}
|
||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
|
||||
|
||||
// disable certificate verification
|
||||
WithTlsConfig(clientTLSConf)
|
||||
|
@ -217,9 +217,9 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
|
|||
func Test_Proxy_Modify_Response(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.Status(500).SendString("not modified")
|
||||
}, t)
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
|
@ -230,7 +230,7 @@ func Test_Proxy_Modify_Response(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
|
@ -243,10 +243,10 @@ func Test_Proxy_Modify_Response(t *testing.T) {
|
|||
func Test_Proxy_Modify_Request(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
b := c.Request().Body()
|
||||
return c.SendString(string(b))
|
||||
}, t)
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
|
@ -257,7 +257,7 @@ func Test_Proxy_Modify_Request(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
|
@ -270,10 +270,10 @@ func Test_Proxy_Modify_Request(t *testing.T) {
|
|||
func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
time.Sleep(2 * time.Second)
|
||||
return c.SendString("fiber is awesome")
|
||||
}, t)
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
|
@ -281,7 +281,7 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
|||
Timeout: 3 * time.Second,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 5000)
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 5000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
|
@ -294,10 +294,10 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
|||
func Test_Proxy_With_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
time.Sleep(1 * time.Second)
|
||||
return c.SendString("fiber is awesome")
|
||||
}, t)
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
|
@ -305,7 +305,7 @@ func Test_Proxy_With_Timeout(t *testing.T) {
|
|||
Timeout: 100 * time.Millisecond,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 2000)
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
|
||||
|
@ -318,16 +318,16 @@ func Test_Proxy_With_Timeout(t *testing.T) {
|
|||
func Test_Proxy_Buffer_Size_Response(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
long := strings.Join(make([]string, 5000), "-")
|
||||
c.Set("Very-Long-Header", long)
|
||||
return c.SendString("ok")
|
||||
}, t)
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{Servers: []string{addr}}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
|
||||
|
@ -337,7 +337,7 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) {
|
|||
ReadBufferSize: 1024 * 8,
|
||||
}))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
@ -357,9 +357,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
|
|||
utils.AssertEqual(t, originalURL, c.OriginalURL())
|
||||
return c.SendString("ok")
|
||||
})
|
||||
_, err1 := app.Test(httptest.NewRequest("GET", "/test", nil))
|
||||
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
// This test requires multiple requests due to zero allocation used in fiber
|
||||
_, err2 := app.Test(httptest.NewRequest("GET", "/test", nil))
|
||||
_, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
utils.AssertEqual(t, nil, err2)
|
||||
|
@ -369,9 +369,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
|
|||
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("hello world")
|
||||
}, t)
|
||||
})
|
||||
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
app.Get("/*", func(c *fiber.Ctx) error {
|
||||
|
@ -386,7 +386,7 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
|
|||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/http://"+addr, nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
s, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -431,9 +431,8 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
|
|||
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
Dial: func(addr string) (net.Conn, error) {
|
||||
return fasthttp.Dial(addr)
|
||||
},
|
||||
|
||||
Dial: fasthttp.Dial,
|
||||
}))
|
||||
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
||||
|
||||
|
@ -447,11 +446,11 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
|
|||
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
target, addr := createProxyTestServer(
|
||||
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
|
||||
)
|
||||
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
|
||||
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
|
||||
|
@ -468,7 +467,7 @@ func Test_ProxyBalancer_Custom_Client(t *testing.T) {
|
|||
Timeout: time.Second,
|
||||
}}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Host = addr
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package recover
|
||||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
@ -23,6 +23,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
EnableStackTrace: false,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package recover
|
||||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
|
||||
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack()))
|
||||
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
|
@ -18,7 +18,7 @@ func New(config ...Config) fiber.Handler {
|
|||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package recover
|
||||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
|
@ -24,7 +24,7 @@ func Test_Recover(t *testing.T) {
|
|||
panic("Hi, I'm an error!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func Test_Recover_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func Test_Recover_EnableStackTrace(t *testing.T) {
|
|||
panic("Hi, I'm an error!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
|
|
|
@ -33,6 +33,8 @@ type Config struct {
|
|||
// It uses a fast UUID generator which will expose the number of
|
||||
// requests made to the server. To conceal this value for better
|
||||
// privacy, use the "utils.UUIDv4" generator.
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Header: fiber.HeaderXRequestID,
|
||||
|
|
|
@ -19,14 +19,14 @@ func Test_RequestID(t *testing.T) {
|
|||
return c.SendString("Hello, World 👋!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
reqid := resp.Header.Get(fiber.HeaderXRequestID)
|
||||
utils.AssertEqual(t, 36, len(reqid))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add(fiber.HeaderXRequestID, reqid)
|
||||
|
||||
resp, err = app.Test(req)
|
||||
|
@ -45,7 +45,7 @@ func Test_RequestID_Next(t *testing.T) {
|
|||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "")
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
@ -54,13 +54,13 @@ func Test_RequestID_Next(t *testing.T) {
|
|||
// go test -run Test_RequestID_Locals
|
||||
func Test_RequestID_Locals(t *testing.T) {
|
||||
t.Parallel()
|
||||
reqId := "ThisIsARequestId"
|
||||
reqID := "ThisIsARequestId"
|
||||
ctxKey := "ThisIsAContextKey"
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Generator: func() string {
|
||||
return reqId
|
||||
return reqID
|
||||
},
|
||||
ContextKey: ctxKey,
|
||||
}))
|
||||
|
@ -68,11 +68,11 @@ func Test_RequestID_Locals(t *testing.T) {
|
|||
var ctxVal string
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
ctxVal = c.Locals(ctxKey).(string)
|
||||
ctxVal = c.Locals(ctxKey).(string) //nolint:forcetypeassert,errcheck // We always store a string in here
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, reqId, ctxVal)
|
||||
utils.AssertEqual(t, reqID, ctxVal)
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ func (s *Session) Save() error
|
|||
func (s *Session) Fresh() bool
|
||||
func (s *Session) ID() string
|
||||
func (s *Session) Keys() []string
|
||||
func (s *Session) SetExpiry(time.Duration)
|
||||
func (s *Session) SetExpiry(time.Duration)
|
||||
```
|
||||
|
||||
**⚠ _Storing `interface{}` values are limited to built-ins Go types_**
|
||||
|
@ -148,7 +148,7 @@ type Config struct {
|
|||
// Optional. Default value utils.UUID
|
||||
KeyGenerator func() string
|
||||
|
||||
// Deprecated, please use KeyLookup
|
||||
// Deprecated: Please use KeyLookup
|
||||
CookieName string
|
||||
|
||||
// Source defines where to obtain the session id
|
||||
|
|
|
@ -2,6 +2,7 @@ package session
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -49,7 +50,7 @@ type Config struct {
|
|||
// Optional. Default value utils.UUIDv4
|
||||
KeyGenerator func() string
|
||||
|
||||
// Deprecated, please use KeyLookup
|
||||
// Deprecated: Please use KeyLookup
|
||||
CookieName string
|
||||
|
||||
// Source defines where to obtain the session id
|
||||
|
@ -68,8 +69,10 @@ const (
|
|||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
//
|
||||
//nolint:gochecknoglobals // Using a global var is fine here
|
||||
var ConfigDefault = Config{
|
||||
Expiration: 24 * time.Hour,
|
||||
Expiration: 24 * time.Hour, //nolint:gomnd // No magic number, just the default config
|
||||
KeyLookup: "cookie:session_id",
|
||||
KeyGenerator: utils.UUIDv4,
|
||||
source: "cookie",
|
||||
|
@ -91,7 +94,7 @@ func configDefault(config ...Config) Config {
|
|||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.CookieName != "" {
|
||||
fmt.Println("[session] CookieName is deprecated, please use KeyLookup")
|
||||
log.Printf("[session] CookieName is deprecated, please use KeyLookup\n")
|
||||
cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName)
|
||||
}
|
||||
if cfg.KeyLookup == "" {
|
||||
|
@ -102,7 +105,8 @@ func configDefault(config ...Config) Config {
|
|||
}
|
||||
|
||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||
if len(selectors) != 2 {
|
||||
const numSelectors = 2
|
||||
if len(selectors) != numSelectors {
|
||||
panic("[session] KeyLookup must in the form of <source>:<name>")
|
||||
}
|
||||
switch Source(selectors[0]) {
|
||||
|
|
|
@ -13,6 +13,7 @@ type data struct {
|
|||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var dataPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
d := new(data)
|
||||
|
@ -22,7 +23,7 @@ var dataPool = sync.Pool{
|
|||
}
|
||||
|
||||
func acquireData() *data {
|
||||
return dataPool.Get().(*data)
|
||||
return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
func (d *data) Reset() {
|
||||
|
|
|
@ -3,11 +3,13 @@ package session
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -21,6 +23,7 @@ type Session struct {
|
|||
exp time.Duration // expiration of this session
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var sessionPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Session)
|
||||
|
@ -28,7 +31,7 @@ var sessionPool = sync.Pool{
|
|||
}
|
||||
|
||||
func acquireSession() *Session {
|
||||
s := sessionPool.Get().(*Session)
|
||||
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||
if s.data == nil {
|
||||
s.data = acquireData()
|
||||
}
|
||||
|
@ -153,7 +156,7 @@ func (s *Session) Save() error {
|
|||
encCache := gob.NewEncoder(s.byteBuffer)
|
||||
err := encCache.Encode(&s.data.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to encode data: %w", err)
|
||||
}
|
||||
|
||||
// copy the data in buffer
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -94,6 +95,8 @@ func Test_Session(t *testing.T) {
|
|||
}
|
||||
|
||||
// go test -run Test_Session_Types
|
||||
//
|
||||
//nolint:forcetypeassert // TODO: Do not force-type assert
|
||||
func Test_Session_Types(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -127,25 +130,27 @@ func Test_Session_Types(t *testing.T) {
|
|||
Name: "John",
|
||||
}
|
||||
// set value
|
||||
var vbool = true
|
||||
var vstring = "str"
|
||||
var vint = 13
|
||||
var vint8 int8 = 13
|
||||
var vint16 int16 = 13
|
||||
var vint32 int32 = 13
|
||||
var vint64 int64 = 13
|
||||
var vuint uint = 13
|
||||
var vuint8 uint8 = 13
|
||||
var vuint16 uint16 = 13
|
||||
var vuint32 uint32 = 13
|
||||
var vuint64 uint64 = 13
|
||||
var vuintptr uintptr = 13
|
||||
var vbyte byte = 'k'
|
||||
var vrune rune = 'k'
|
||||
var vfloat32 float32 = 13
|
||||
var vfloat64 float64 = 13
|
||||
var vcomplex64 complex64 = 13
|
||||
var vcomplex128 complex128 = 13
|
||||
var (
|
||||
vbool = true
|
||||
vstring = "str"
|
||||
vint = 13
|
||||
vint8 int8 = 13
|
||||
vint16 int16 = 13
|
||||
vint32 int32 = 13
|
||||
vint64 int64 = 13
|
||||
vuint uint = 13
|
||||
vuint8 uint8 = 13
|
||||
vuint16 uint16 = 13
|
||||
vuint32 uint32 = 13
|
||||
vuint64 uint64 = 13
|
||||
vuintptr uintptr = 13
|
||||
vbyte byte = 'k'
|
||||
vrune = 'k'
|
||||
vfloat32 float32 = 13
|
||||
vfloat64 float64 = 13
|
||||
vcomplex64 complex64 = 13
|
||||
vcomplex128 complex128 = 13
|
||||
)
|
||||
sess.Set("vuser", vuser)
|
||||
sess.Set("vbool", vbool)
|
||||
sess.Set("vstring", vstring)
|
||||
|
@ -212,7 +217,8 @@ func Test_Session_Store_Reset(t *testing.T) {
|
|||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// make sure its new
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
// set value & save
|
||||
|
@ -224,7 +230,8 @@ func Test_Session_Store_Reset(t *testing.T) {
|
|||
utils.AssertEqual(t, nil, store.Reset())
|
||||
|
||||
// make sure the session is recreated
|
||||
sess, _ = store.Get(ctx)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
utils.AssertEqual(t, nil, sess.Get("hello"))
|
||||
}
|
||||
|
@ -242,12 +249,13 @@ func Test_Session_Save(t *testing.T) {
|
|||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// save session
|
||||
err := sess.Save()
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
})
|
||||
|
||||
|
@ -262,12 +270,13 @@ func Test_Session_Save(t *testing.T) {
|
|||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// save session
|
||||
err := sess.Save()
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
|
||||
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
|
||||
|
@ -287,7 +296,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
|
|||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
|
@ -295,18 +305,20 @@ func Test_Session_Save_Expiration(t *testing.T) {
|
|||
sess.SetExpiry(time.Second * 5)
|
||||
|
||||
// save session
|
||||
err := sess.Save()
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// here you need to get the old session yet
|
||||
sess, _ = store.Get(ctx)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "john", sess.Get("name"))
|
||||
|
||||
// just to make sure the session has been expired
|
||||
time.Sleep(time.Second * 5)
|
||||
|
||||
// here you should get a new session
|
||||
sess, _ = store.Get(ctx)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, nil, sess.Get("name"))
|
||||
})
|
||||
}
|
||||
|
@ -325,7 +337,8 @@ func Test_Session_Reset(t *testing.T) {
|
|||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
sess.Set("name", "fenny")
|
||||
utils.AssertEqual(t, nil, sess.Destroy())
|
||||
|
@ -345,14 +358,16 @@ func Test_Session_Reset(t *testing.T) {
|
|||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// set value & save
|
||||
sess.Set("name", "fenny")
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
sess, _ = store.Get(ctx)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
err := sess.Destroy()
|
||||
err = sess.Destroy()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
|
||||
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
|
||||
|
@ -383,7 +398,8 @@ func Test_Session_Cookie(t *testing.T) {
|
|||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
// cookie should be set on Save ( even if empty data )
|
||||
|
@ -401,12 +417,14 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
|
|||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
sess, _ := store.Get(ctx)
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
sess.Set("id", "1")
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
sess, _ = store.Get(ctx)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
sess.Set("name", "john")
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
|
||||
|
@ -497,7 +515,7 @@ func Benchmark_Session(b *testing.B) {
|
|||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c)
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
}
|
||||
|
@ -512,7 +530,7 @@ func Benchmark_Session(b *testing.B) {
|
|||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c)
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
}
|
||||
|
|
|
@ -2,11 +2,13 @@ package session
|
|||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -14,6 +16,7 @@ type Store struct {
|
|||
Config
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||
var mux sync.Mutex
|
||||
|
||||
func New(config ...Config) *Store {
|
||||
|
@ -31,7 +34,7 @@ func New(config ...Config) *Store {
|
|||
|
||||
// RegisterType will allow you to encode/decode custom types
|
||||
// into any Storage provider
|
||||
func (s *Store) RegisterType(i interface{}) {
|
||||
func (*Store) RegisterType(i interface{}) {
|
||||
gob.Register(i)
|
||||
}
|
||||
|
||||
|
@ -70,11 +73,11 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
|||
if raw != nil && err == nil {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
_, _ = sess.byteBuffer.Write(raw)
|
||||
_, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
|
||||
encCache := gob.NewDecoder(sess.byteBuffer)
|
||||
err := encCache.Decode(&sess.data.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package skip
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// New creates a middleware handler which skips the wrapped handler
|
||||
// if the exclude predicate returns true.
|
||||
|
|
|
@ -17,7 +17,7 @@ func Test_Skip(t *testing.T) {
|
|||
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return true }))
|
||||
app.Get("/", helloWorldHandler)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ func Test_SkipFalse(t *testing.T) {
|
|||
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return false }))
|
||||
app.Get("/", helloWorldHandler)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func Test_SkipNilFunc(t *testing.T) {
|
|||
app.Use(skip.New(errTeapotHandler, nil))
|
||||
app.Get("/", helloWorldHandler)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
}
|
||||
|
|
|
@ -18,7 +18,8 @@ func Test_Timeout(t *testing.T) {
|
|||
// fiber instance
|
||||
app := fiber.New()
|
||||
h := New(func(c *fiber.Ctx) error {
|
||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
|
||||
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
|
||||
}
|
||||
|
@ -26,12 +27,12 @@ func Test_Timeout(t *testing.T) {
|
|||
}, 100*time.Millisecond)
|
||||
app.Get("/test/:sleepTime", h)
|
||||
testTimeout := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
}
|
||||
testSucces := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
}
|
||||
|
@ -49,7 +50,8 @@ func Test_TimeoutWithCustomError(t *testing.T) {
|
|||
// fiber instance
|
||||
app := fiber.New()
|
||||
h := New(func(c *fiber.Ctx) error {
|
||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
|
||||
return fmt.Errorf("%w: execution error", err)
|
||||
}
|
||||
|
@ -57,12 +59,12 @@ func Test_TimeoutWithCustomError(t *testing.T) {
|
|||
}, 100*time.Millisecond, ErrFooTimeOut)
|
||||
app.Get("/test/:sleepTime", h)
|
||||
testTimeout := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
}
|
||||
testSucces := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue