mirror of https://github.com/gofiber/fiber.git
🚀 Feature: Add and apply more stricter golangci-lint linting rules (#2286)
* golangci-lint: add and apply more stricter linting rules * github: drop security workflow now that we use gosec linter inside golangci-lint * github: use official golangci-lint CI linter * Add editorconfig and gitattributes filepull/2313/head
parent
7327a17951
commit
167a8b5e94
|
@ -0,0 +1,8 @@
|
||||||
|
; This file is for unifying the coding style for different editors and IDEs.
|
||||||
|
; More information at http://editorconfig.org
|
||||||
|
; This style originates from https://github.com/fewagency/best-practices
|
||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
charset = utf-8
|
||||||
|
end_of_line = lf
|
|
@ -0,0 +1,12 @@
|
||||||
|
# Handle line endings automatically for files detected as text
|
||||||
|
# and leave all files detected as binary untouched.
|
||||||
|
* text=auto eol=lf
|
||||||
|
|
||||||
|
# Force batch scripts to always use CRLF line endings so that if a repo is accessed
|
||||||
|
# in Windows via a file share from Linux, the scripts will work.
|
||||||
|
*.{cmd,[cC][mM][dD]} text eol=crlf
|
||||||
|
*.{bat,[bB][aA][tT]} text eol=crlf
|
||||||
|
|
||||||
|
# Force bash scripts to always use LF line endings so that if a repo is accessed
|
||||||
|
# in Unix via a file share from Windows, the scripts will work.
|
||||||
|
*.sh text eol=lf
|
|
@ -1,17 +1,28 @@
|
||||||
|
# Adapted from https://github.com/golangci/golangci-lint-action/blob/b56f6f529003f1c81d4d759be6bd5f10bf9a0fa0/README.md#how-to-use
|
||||||
|
|
||||||
|
name: golangci-lint
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
tags:
|
||||||
|
- v*
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
name: Linter
|
permissions:
|
||||||
|
contents: read
|
||||||
jobs:
|
jobs:
|
||||||
Golint:
|
golangci:
|
||||||
|
name: lint
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Fetch Repository
|
- uses: actions/setup-go@v3
|
||||||
uses: actions/checkout@v3
|
|
||||||
- name: Run Golint
|
|
||||||
uses: reviewdog/action-golangci-lint@v2
|
|
||||||
with:
|
with:
|
||||||
golangci_lint_flags: "--tests=false"
|
# NOTE: Keep this in sync with the version from go.mod
|
||||||
|
go-version: 1.19
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v3
|
||||||
|
with:
|
||||||
|
# NOTE: Keep this in sync with the version from .golangci.yml
|
||||||
|
version: v1.50.1
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
name: Security
|
|
||||||
jobs:
|
|
||||||
Gosec:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Fetch Repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
- name: Run Gosec
|
|
||||||
uses: securego/gosec@master
|
|
||||||
with:
|
|
||||||
args: -exclude-dir=internal/*/ ./...
|
|
|
@ -0,0 +1,258 @@
|
||||||
|
# Created based on v1.50.1
|
||||||
|
# NOTE: Keep this in sync with the version in .github/workflows/linter.yml
|
||||||
|
|
||||||
|
run:
|
||||||
|
modules-download-mode: readonly
|
||||||
|
skip-dirs-use-default: false
|
||||||
|
skip-dirs:
|
||||||
|
- internal # TODO: Also apply proper linting for internal dir
|
||||||
|
|
||||||
|
output:
|
||||||
|
sort-results: true
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
# TODO: Eventually enable these checks
|
||||||
|
# depguard:
|
||||||
|
# include-go-root: true
|
||||||
|
# packages:
|
||||||
|
# - flag
|
||||||
|
# - io/ioutil
|
||||||
|
# - reflect
|
||||||
|
# - unsafe
|
||||||
|
# packages-with-error-message:
|
||||||
|
# - flag: '`flag` package is only allowed in main.go'
|
||||||
|
# - io/ioutil: '`io/ioutil` package is deprecated, use the `io` and `os` package instead'
|
||||||
|
# - reflect: '`reflect` package is dangerous to use'
|
||||||
|
# - unsafe: '`unsafe` package is dangerous to use'
|
||||||
|
|
||||||
|
errcheck:
|
||||||
|
check-type-assertions: true
|
||||||
|
check-blank: true
|
||||||
|
disable-default-exclusions: true
|
||||||
|
|
||||||
|
errchkjson:
|
||||||
|
report-no-exported: true
|
||||||
|
|
||||||
|
exhaustive:
|
||||||
|
default-signifies-exhaustive: true
|
||||||
|
|
||||||
|
forbidigo:
|
||||||
|
forbid:
|
||||||
|
- ^(fmt\.Print(|f|ln)|print|println)$
|
||||||
|
- 'http\.Default(Client|Transport)'
|
||||||
|
# TODO: Eventually enable these patterns
|
||||||
|
# - 'time\.Sleep'
|
||||||
|
# - 'panic'
|
||||||
|
|
||||||
|
gci:
|
||||||
|
sections:
|
||||||
|
- standard
|
||||||
|
- prefix(github.com/gofiber/fiber)
|
||||||
|
- default
|
||||||
|
- blank
|
||||||
|
- dot
|
||||||
|
custom-order: true
|
||||||
|
|
||||||
|
gocritic:
|
||||||
|
disabled-checks:
|
||||||
|
- ifElseChain
|
||||||
|
|
||||||
|
gofumpt:
|
||||||
|
module-path: github.com/gofiber/fiber
|
||||||
|
extra-rules: true
|
||||||
|
|
||||||
|
gosec:
|
||||||
|
config:
|
||||||
|
global:
|
||||||
|
audit: true
|
||||||
|
|
||||||
|
govet:
|
||||||
|
check-shadowing: true
|
||||||
|
enable-all: true
|
||||||
|
disable:
|
||||||
|
- shadow
|
||||||
|
- fieldalignment
|
||||||
|
|
||||||
|
grouper:
|
||||||
|
import-require-single-import: true
|
||||||
|
import-require-grouping: true
|
||||||
|
|
||||||
|
misspell:
|
||||||
|
locale: US
|
||||||
|
|
||||||
|
nolintlint:
|
||||||
|
require-explanation: true
|
||||||
|
require-specific: true
|
||||||
|
|
||||||
|
nonamedreturns:
|
||||||
|
report-error-in-defer: true
|
||||||
|
|
||||||
|
predeclared:
|
||||||
|
q: true
|
||||||
|
|
||||||
|
promlinter:
|
||||||
|
strict: true
|
||||||
|
|
||||||
|
revive:
|
||||||
|
enable-all-rules: true
|
||||||
|
rules:
|
||||||
|
# Provided by gomnd linter
|
||||||
|
- name: add-constant
|
||||||
|
disabled: true
|
||||||
|
- name: argument-limit
|
||||||
|
disabled: true
|
||||||
|
# Provided by bidichk
|
||||||
|
- name: banned-characters
|
||||||
|
disabled: true
|
||||||
|
- name: cognitive-complexity
|
||||||
|
disabled: true
|
||||||
|
- name: cyclomatic
|
||||||
|
disabled: true
|
||||||
|
- name: exported
|
||||||
|
disabled: true
|
||||||
|
- name: file-header
|
||||||
|
disabled: true
|
||||||
|
- name: function-result-limit
|
||||||
|
disabled: true
|
||||||
|
- name: function-length
|
||||||
|
disabled: true
|
||||||
|
- name: line-length-limit
|
||||||
|
disabled: true
|
||||||
|
- name: max-public-structs
|
||||||
|
disabled: true
|
||||||
|
- name: modifies-parameter
|
||||||
|
disabled: true
|
||||||
|
- name: nested-structs
|
||||||
|
disabled: true
|
||||||
|
- name: package-comments
|
||||||
|
disabled: true
|
||||||
|
|
||||||
|
stylecheck:
|
||||||
|
checks:
|
||||||
|
- all
|
||||||
|
- -ST1000
|
||||||
|
- -ST1020
|
||||||
|
- -ST1021
|
||||||
|
- -ST1022
|
||||||
|
|
||||||
|
tagliatelle:
|
||||||
|
case:
|
||||||
|
rules:
|
||||||
|
json: snake
|
||||||
|
|
||||||
|
#tenv:
|
||||||
|
# all: true
|
||||||
|
|
||||||
|
#unparam:
|
||||||
|
# check-exported: true
|
||||||
|
|
||||||
|
wrapcheck:
|
||||||
|
ignorePackageGlobs:
|
||||||
|
- github.com/gofiber/fiber/*
|
||||||
|
- github.com/valyala/fasthttp
|
||||||
|
|
||||||
|
issues:
|
||||||
|
exclude-use-default: false
|
||||||
|
|
||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
- asasalint
|
||||||
|
- asciicheck
|
||||||
|
- bidichk
|
||||||
|
- bodyclose
|
||||||
|
- containedctx
|
||||||
|
- contextcheck
|
||||||
|
# - cyclop
|
||||||
|
# - deadcode
|
||||||
|
# - decorder
|
||||||
|
- depguard
|
||||||
|
- dogsled
|
||||||
|
# - dupl
|
||||||
|
# - dupword
|
||||||
|
- durationcheck
|
||||||
|
- errcheck
|
||||||
|
- errchkjson
|
||||||
|
- errname
|
||||||
|
- errorlint
|
||||||
|
- execinquery
|
||||||
|
- exhaustive
|
||||||
|
# - exhaustivestruct
|
||||||
|
# - exhaustruct
|
||||||
|
- exportloopref
|
||||||
|
- forbidigo
|
||||||
|
- forcetypeassert
|
||||||
|
# - funlen
|
||||||
|
- gci
|
||||||
|
- gochecknoglobals
|
||||||
|
- gochecknoinits
|
||||||
|
# - gocognit
|
||||||
|
- goconst
|
||||||
|
- gocritic
|
||||||
|
# - gocyclo
|
||||||
|
# - godot
|
||||||
|
# - godox
|
||||||
|
# - goerr113
|
||||||
|
- gofmt
|
||||||
|
- gofumpt
|
||||||
|
# - goheader
|
||||||
|
- goimports
|
||||||
|
# - golint
|
||||||
|
- gomnd
|
||||||
|
- gomoddirectives
|
||||||
|
# - gomodguard
|
||||||
|
- goprintffuncname
|
||||||
|
- gosec
|
||||||
|
- gosimple
|
||||||
|
- govet
|
||||||
|
- grouper
|
||||||
|
# - ifshort
|
||||||
|
# - importas
|
||||||
|
- ineffassign
|
||||||
|
# - interfacebloat
|
||||||
|
# - interfacer
|
||||||
|
# - ireturn
|
||||||
|
# - lll
|
||||||
|
- loggercheck
|
||||||
|
# - maintidx
|
||||||
|
# - makezero
|
||||||
|
# - maligned
|
||||||
|
- misspell
|
||||||
|
- nakedret
|
||||||
|
# - nestif
|
||||||
|
- nilerr
|
||||||
|
- nilnil
|
||||||
|
# - nlreturn
|
||||||
|
- noctx
|
||||||
|
- nolintlint
|
||||||
|
- nonamedreturns
|
||||||
|
# - nosnakecase
|
||||||
|
- nosprintfhostport
|
||||||
|
# - paralleltest # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
|
||||||
|
# - prealloc
|
||||||
|
- predeclared
|
||||||
|
- promlinter
|
||||||
|
- reassign
|
||||||
|
- revive
|
||||||
|
- rowserrcheck
|
||||||
|
# - scopelint
|
||||||
|
- sqlclosecheck
|
||||||
|
- staticcheck
|
||||||
|
# - structcheck
|
||||||
|
- stylecheck
|
||||||
|
- tagliatelle
|
||||||
|
# - tenv # TODO: Enable once we drop support for Go 1.16
|
||||||
|
# - testableexamples
|
||||||
|
# - testpackage # TODO: Enable once https://github.com/gofiber/fiber/issues/2252 is implemented
|
||||||
|
- thelper
|
||||||
|
# - tparallel # TODO: Enable once https://github.com/gofiber/fiber/issues/2254 is implemented
|
||||||
|
- typecheck
|
||||||
|
- unconvert
|
||||||
|
- unparam
|
||||||
|
- unused
|
||||||
|
- usestdlibvars
|
||||||
|
# - varcheck
|
||||||
|
# - varnamelen
|
||||||
|
- wastedassign
|
||||||
|
- whitespace
|
||||||
|
- wrapcheck
|
||||||
|
# - wsl
|
58
app.go
58
app.go
|
@ -14,6 +14,7 @@ import (
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
@ -24,6 +25,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -306,7 +308,7 @@ type Config struct {
|
||||||
|
|
||||||
// FEATURE: v2.3.x
|
// FEATURE: v2.3.x
|
||||||
// The router executes the same handler by default if StrictRouting or CaseSensitive is disabled.
|
// The router executes the same handler by default if StrictRouting or CaseSensitive is disabled.
|
||||||
// Enabling RedirectFixedPath will change this behaviour into a client redirect to the original route path.
|
// Enabling RedirectFixedPath will change this behavior into a client redirect to the original route path.
|
||||||
// Using the status code 301 for GET requests and 308 for all other request methods.
|
// Using the status code 301 for GET requests and 308 for all other request methods.
|
||||||
//
|
//
|
||||||
// Default: false
|
// Default: false
|
||||||
|
@ -454,6 +456,8 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTP methods enabled by default
|
// HTTP methods enabled by default
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var DefaultMethods = []string{
|
var DefaultMethods = []string{
|
||||||
MethodGet,
|
MethodGet,
|
||||||
MethodHead,
|
MethodHead,
|
||||||
|
@ -467,7 +471,7 @@ var DefaultMethods = []string{
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultErrorHandler that process return errors from handlers
|
// DefaultErrorHandler that process return errors from handlers
|
||||||
var DefaultErrorHandler = func(c *Ctx, err error) error {
|
func DefaultErrorHandler(c *Ctx, err error) error {
|
||||||
code := StatusInternalServerError
|
code := StatusInternalServerError
|
||||||
var e *Error
|
var e *Error
|
||||||
if errors.As(err, &e) {
|
if errors.As(err, &e) {
|
||||||
|
@ -519,7 +523,7 @@ func New(config ...Config) *App {
|
||||||
|
|
||||||
if app.config.ETag {
|
if app.config.ETag {
|
||||||
if !IsChild() {
|
if !IsChild() {
|
||||||
fmt.Println("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.")
|
log.Printf("[Warning] Config.ETag is deprecated since v2.0.6, please use 'middleware/etag'.\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -587,7 +591,7 @@ func (app *App) handleTrustedProxy(ipAddress string) {
|
||||||
if strings.Contains(ipAddress, "/") {
|
if strings.Contains(ipAddress, "/") {
|
||||||
_, ipNet, err := net.ParseCIDR(ipAddress)
|
_, ipNet, err := net.ParseCIDR(ipAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err)
|
log.Printf("[Warning] IP range %q could not be parsed: %v\n", ipAddress, err)
|
||||||
} else {
|
} else {
|
||||||
app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet)
|
app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet)
|
||||||
}
|
}
|
||||||
|
@ -822,7 +826,7 @@ func (app *App) Config() Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler returns the server handler.
|
// Handler returns the server handler.
|
||||||
func (app *App) Handler() fasthttp.RequestHandler {
|
func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
|
||||||
// prepare the server for the start
|
// prepare the server for the start
|
||||||
app.startupProcess()
|
app.startupProcess()
|
||||||
return app.handler
|
return app.handler
|
||||||
|
@ -887,7 +891,7 @@ func (app *App) Hooks() *Hooks {
|
||||||
|
|
||||||
// Test is used for internal debugging by passing a *http.Request.
|
// Test is used for internal debugging by passing a *http.Request.
|
||||||
// Timeout is optional and defaults to 1s, -1 will disable it completely.
|
// Timeout is optional and defaults to 1s, -1 will disable it completely.
|
||||||
func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response, err error) {
|
func (app *App) Test(req *http.Request, msTimeout ...int) (*http.Response, error) {
|
||||||
// Set timeout
|
// Set timeout
|
||||||
timeout := 1000
|
timeout := 1000
|
||||||
if len(msTimeout) > 0 {
|
if len(msTimeout) > 0 {
|
||||||
|
@ -902,15 +906,15 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
|
||||||
// Dump raw http request
|
// Dump raw http request
|
||||||
dump, err := httputil.DumpRequest(req, true)
|
dump, err := httputil.DumpRequest(req, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to dump request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create test connection
|
// Create test connection
|
||||||
conn := new(testConn)
|
conn := new(testConn)
|
||||||
|
|
||||||
// Write raw http request
|
// Write raw http request
|
||||||
if _, err = conn.r.Write(dump); err != nil {
|
if _, err := conn.r.Write(dump); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to write: %w", err)
|
||||||
}
|
}
|
||||||
// prepare the server for the start
|
// prepare the server for the start
|
||||||
app.startupProcess()
|
app.startupProcess()
|
||||||
|
@ -943,7 +947,7 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for errors
|
// Check for errors
|
||||||
if err != nil && err != fasthttp.ErrGetOnly {
|
if err != nil && !errors.Is(err, fasthttp.ErrGetOnly) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -951,12 +955,17 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
|
||||||
buffer := bufio.NewReader(&conn.w)
|
buffer := bufio.NewReader(&conn.w)
|
||||||
|
|
||||||
// Convert raw http response to *http.Response
|
// Convert raw http response to *http.Response
|
||||||
return http.ReadResponse(buffer, req)
|
res, err := http.ReadResponse(buffer, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type disableLogger struct{}
|
type disableLogger struct{}
|
||||||
|
|
||||||
func (dl *disableLogger) Printf(_ string, _ ...interface{}) {
|
func (*disableLogger) Printf(_ string, _ ...interface{}) {
|
||||||
// fmt.Println(fmt.Sprintf(format, args...))
|
// fmt.Println(fmt.Sprintf(format, args...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -967,7 +976,7 @@ func (app *App) init() *App {
|
||||||
// Only load templates if a view engine is specified
|
// Only load templates if a view engine is specified
|
||||||
if app.config.Views != nil {
|
if app.config.Views != nil {
|
||||||
if err := app.config.Views.Load(); err != nil {
|
if err := app.config.Views.Load(); err != nil {
|
||||||
fmt.Printf("views: %v\n", err)
|
log.Printf("[Warning]: failed to load views: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1039,25 +1048,30 @@ func (app *App) ErrorHandler(ctx *Ctx, err error) error {
|
||||||
// errors before calling the application's error handler method.
|
// errors before calling the application's error handler method.
|
||||||
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
|
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
|
||||||
c := app.AcquireCtx(fctx)
|
c := app.AcquireCtx(fctx)
|
||||||
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
|
var errNetOP *net.OpError
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.As(err, new(*fasthttp.ErrSmallBuffer)):
|
||||||
err = ErrRequestHeaderFieldsTooLarge
|
err = ErrRequestHeaderFieldsTooLarge
|
||||||
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
case errors.As(err, &errNetOP) && errNetOP.Timeout():
|
||||||
err = ErrRequestTimeout
|
err = ErrRequestTimeout
|
||||||
} else if err == fasthttp.ErrBodyTooLarge {
|
case errors.Is(err, fasthttp.ErrBodyTooLarge):
|
||||||
err = ErrRequestEntityTooLarge
|
err = ErrRequestEntityTooLarge
|
||||||
} else if err == fasthttp.ErrGetOnly {
|
case errors.Is(err, fasthttp.ErrGetOnly):
|
||||||
err = ErrMethodNotAllowed
|
err = ErrMethodNotAllowed
|
||||||
} else if strings.Contains(err.Error(), "timeout") {
|
case strings.Contains(err.Error(), "timeout"):
|
||||||
err = ErrRequestTimeout
|
err = ErrRequestTimeout
|
||||||
} else {
|
default:
|
||||||
err = NewError(StatusBadRequest, err.Error())
|
err = NewError(StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if catch := app.ErrorHandler(c, err); catch != nil {
|
if catch := app.ErrorHandler(c, err); catch != nil {
|
||||||
_ = c.SendStatus(StatusInternalServerError)
|
log.Printf("serverErrorHandler: failed to call ErrorHandler: %v\n", catch)
|
||||||
|
_ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
app.ReleaseCtx(c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// startupProcess Is the method which executes all the necessary processes just before the start of the server.
|
// startupProcess Is the method which executes all the necessary processes just before the start of the server.
|
||||||
|
|
65
app_test.go
65
app_test.go
|
@ -2,6 +2,7 @@
|
||||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||||
// 📌 API Documentation: https://docs.gofiber.io
|
// 📌 API Documentation: https://docs.gofiber.io
|
||||||
|
|
||||||
|
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||||
package fiber
|
package fiber
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -23,15 +24,16 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
"github.com/valyala/fasthttp/fasthttputil"
|
"github.com/valyala/fasthttp/fasthttputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testEmptyHandler = func(c *Ctx) error {
|
func testEmptyHandler(_ *Ctx) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStatus200(t *testing.T, app *App, url string, method string) {
|
func testStatus200(t *testing.T, app *App, url, method string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
req := httptest.NewRequest(method, url, nil)
|
req := httptest.NewRequest(method, url, nil)
|
||||||
|
@ -42,6 +44,8 @@ func testStatus200(t *testing.T, app *App, url string, method string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBodyError string) {
|
func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBodyError string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, 500, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, 500, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
|
@ -140,7 +144,6 @@ func Test_App_ServerErrorHandler_SmallReadBuffer(t *testing.T) {
|
||||||
logHeaderSlice := make([]string, 5000)
|
logHeaderSlice := make([]string, 5000)
|
||||||
request.Header.Set("Very-Long-Header", strings.Join(logHeaderSlice, "-"))
|
request.Header.Set("Very-Long-Header", strings.Join(logHeaderSlice, "-"))
|
||||||
_, err := app.Test(request)
|
_, err := app.Test(request)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expect an error at app.Test(request)")
|
t.Error("Expect an error at app.Test(request)")
|
||||||
}
|
}
|
||||||
|
@ -470,7 +473,6 @@ func Test_App_Use_MultiplePrefix(t *testing.T) {
|
||||||
body, err = io.ReadAll(resp.Body)
|
body, err = io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "/test/doe", string(body))
|
utils.AssertEqual(t, "/test/doe", string(body))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_App_Use_StrictRouting(t *testing.T) {
|
func Test_App_Use_StrictRouting(t *testing.T) {
|
||||||
|
@ -515,7 +517,7 @@ func Test_App_Add_Method_Test(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
methods := append(DefaultMethods, "JOHN")
|
methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here
|
||||||
app := New(Config{
|
app := New(Config{
|
||||||
RequestMethods: methods,
|
RequestMethods: methods,
|
||||||
})
|
})
|
||||||
|
@ -780,7 +782,7 @@ func Test_App_ShutdownWithTimeout(t *testing.T) {
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
t.Fatal("idle connections not closed on shutdown")
|
t.Fatal("idle connections not closed on shutdown")
|
||||||
case err := <-shutdownErr:
|
case err := <-shutdownErr:
|
||||||
if err == nil || err != context.DeadlineExceeded {
|
if err == nil || !errors.Is(err, context.DeadlineExceeded) {
|
||||||
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
|
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -852,7 +854,7 @@ func Test_App_Static_MaxAge(t *testing.T) {
|
||||||
|
|
||||||
app.Static("/", "./.github", Static{MaxAge: 100})
|
app.Static("/", "./.github", Static{MaxAge: 100})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil))
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||||
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
|
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
|
||||||
|
@ -866,19 +868,19 @@ func Test_App_Static_Custom_CacheControl(t *testing.T) {
|
||||||
app := New()
|
app := New()
|
||||||
|
|
||||||
app.Static("/", "./.github", Static{ModifyResponse: func(c *Ctx) error {
|
app.Static("/", "./.github", Static{ModifyResponse: func(c *Ctx) error {
|
||||||
if strings.Contains(string(c.GetRespHeader("Content-Type")), "text/html") {
|
if strings.Contains(c.GetRespHeader("Content-Type"), "text/html") {
|
||||||
c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}})
|
}})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/index.html", nil))
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/index.html", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, "no-cache, no-store, must-revalidate", resp.Header.Get(HeaderCacheControl), "CacheControl Control")
|
utils.AssertEqual(t, "no-cache, no-store, must-revalidate", resp.Header.Get(HeaderCacheControl), "CacheControl Control")
|
||||||
|
|
||||||
normal_resp, normal_err := app.Test(httptest.NewRequest("GET", "/config.yml", nil))
|
respNormal, errNormal := app.Test(httptest.NewRequest(MethodGet, "/config.yml", nil))
|
||||||
utils.AssertEqual(t, nil, normal_err, "app.Test(req)")
|
utils.AssertEqual(t, nil, errNormal, "app.Test(req)")
|
||||||
utils.AssertEqual(t, "", normal_resp.Header.Get(HeaderCacheControl), "CacheControl Control")
|
utils.AssertEqual(t, "", respNormal.Header.Get(HeaderCacheControl), "CacheControl Control")
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_App_Static_Download
|
// go test -run Test_App_Static_Download
|
||||||
|
@ -890,7 +892,7 @@ func Test_App_Static_Download(t *testing.T) {
|
||||||
|
|
||||||
app.Static("/fiber.png", "./.github/testdata/fs/img/fiber.png", Static{Download: true})
|
app.Static("/fiber.png", "./.github/testdata/fs/img/fiber.png", Static{Download: true})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/fiber.png", nil))
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/fiber.png", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||||
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
|
utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "")
|
||||||
|
@ -1067,7 +1069,7 @@ func Test_App_Static_Next(t *testing.T) {
|
||||||
|
|
||||||
t.Run("app.Static is skipped: invoking Get handler", func(t *testing.T) {
|
t.Run("app.Static is skipped: invoking Get handler", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(MethodGet, "/", nil)
|
||||||
req.Header.Set("X-Custom-Header", "skip")
|
req.Header.Set("X-Custom-Header", "skip")
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -1082,7 +1084,7 @@ func Test_App_Static_Next(t *testing.T) {
|
||||||
|
|
||||||
t.Run("app.Static is not skipped: serving index.html", func(t *testing.T) {
|
t.Run("app.Static is not skipped: serving index.html", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(MethodGet, "/", nil)
|
||||||
req.Header.Set("X-Custom-Header", "don't skip")
|
req.Header.Set("X-Custom-Header", "don't skip")
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -1379,7 +1381,7 @@ func Test_Test_DumpError(t *testing.T) {
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", errorReader(0)))
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", errorReader(0)))
|
||||||
utils.AssertEqual(t, true, resp == nil)
|
utils.AssertEqual(t, true, resp == nil)
|
||||||
utils.AssertEqual(t, "errorReader", err.Error())
|
utils.AssertEqual(t, "failed to dump request: errorReader", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_App_Handler
|
// go test -run Test_App_Handler
|
||||||
|
@ -1393,7 +1395,7 @@ type invalidView struct{}
|
||||||
|
|
||||||
func (invalidView) Load() error { return errors.New("invalid view") }
|
func (invalidView) Load() error { return errors.New("invalid view") }
|
||||||
|
|
||||||
func (i invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") }
|
func (invalidView) Render(io.Writer, string, interface{}, ...string) error { panic("implement me") }
|
||||||
|
|
||||||
// go test -run Test_App_Init_Error_View
|
// go test -run Test_App_Init_Error_View
|
||||||
func Test_App_Init_Error_View(t *testing.T) {
|
func Test_App_Init_Error_View(t *testing.T) {
|
||||||
|
@ -1405,7 +1407,9 @@ func Test_App_Init_Error_View(t *testing.T) {
|
||||||
utils.AssertEqual(t, "implement me", fmt.Sprintf("%v", err))
|
utils.AssertEqual(t, "implement me", fmt.Sprintf("%v", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
_ = app.config.Views.Render(nil, "", nil)
|
|
||||||
|
err := app.config.Views.Render(nil, "", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_App_Stack
|
// go test -run Test_App_Stack
|
||||||
|
@ -1535,11 +1539,12 @@ func Test_App_SmallReadBuffer(t *testing.T) {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
resp, err := http.Get("http://127.0.0.1:4006/small-read-buffer")
|
req, err := http.NewRequestWithContext(context.Background(), MethodGet, "http://127.0.0.1:4006/small-read-buffer", http.NoBody)
|
||||||
if resp != nil {
|
|
||||||
utils.AssertEqual(t, 431, resp.StatusCode)
|
|
||||||
}
|
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
var client http.Client
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 431, resp.StatusCode)
|
||||||
utils.AssertEqual(t, nil, app.Shutdown())
|
utils.AssertEqual(t, nil, app.Shutdown())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -1572,13 +1577,13 @@ func Test_App_New_Test_Parallel(t *testing.T) {
|
||||||
t.Run("Test_App_New_Test_Parallel_1", func(t *testing.T) {
|
t.Run("Test_App_New_Test_Parallel_1", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
app := New(Config{Immutable: true})
|
app := New(Config{Immutable: true})
|
||||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
_, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
})
|
})
|
||||||
t.Run("Test_App_New_Test_Parallel_2", func(t *testing.T) {
|
t.Run("Test_App_New_Test_Parallel_2", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
app := New(Config{Immutable: true})
|
app := New(Config{Immutable: true})
|
||||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
_, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1591,7 +1596,7 @@ func Test_App_ReadBodyStream(t *testing.T) {
|
||||||
return c.SendString(fmt.Sprintf("%v %s", c.Request().IsBodyStream(), c.Body()))
|
return c.SendString(fmt.Sprintf("%v %s", c.Request().IsBodyStream(), c.Body()))
|
||||||
})
|
})
|
||||||
testString := "this is a test"
|
testString := "this is a test"
|
||||||
resp, err := app.Test(httptest.NewRequest("POST", "/", bytes.NewBufferString(testString)))
|
resp, err := app.Test(httptest.NewRequest(MethodPost, "/", bytes.NewBufferString(testString)))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err, "io.ReadAll(resp.Body)")
|
utils.AssertEqual(t, nil, err, "io.ReadAll(resp.Body)")
|
||||||
|
@ -1615,12 +1620,12 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
|
||||||
}
|
}
|
||||||
file, err := mpf.File["test"][0].Open()
|
file, err := mpf.File["test"][0].Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to open: %w", err)
|
||||||
}
|
}
|
||||||
buffer := make([]byte, len(testString))
|
buffer := make([]byte, len(testString))
|
||||||
n, err := file.Read(buffer)
|
n, err := file.Read(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read: %w", err)
|
||||||
}
|
}
|
||||||
if n != len(testString) {
|
if n != len(testString) {
|
||||||
return fmt.Errorf("bad read length")
|
return fmt.Errorf("bad read length")
|
||||||
|
@ -1636,7 +1641,7 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
|
||||||
utils.AssertEqual(t, len(testString), n, "writer n")
|
utils.AssertEqual(t, len(testString), n, "writer n")
|
||||||
utils.AssertEqual(t, nil, w.Close(), "w.Close()")
|
utils.AssertEqual(t, nil, w.Close(), "w.Close()")
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/", b)
|
req := httptest.NewRequest(MethodPost, "/", b)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
|
@ -1659,7 +1664,7 @@ func Test_App_Test_no_timeout_infinitely(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
req := httptest.NewRequest(MethodGet, "/", http.NoBody)
|
||||||
_, err = app.Test(req, -1)
|
_, err = app.Test(req, -1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -1696,7 +1701,7 @@ func Test_App_SetTLSHandler(t *testing.T) {
|
||||||
|
|
||||||
func Test_App_AddCustomRequestMethod(t *testing.T) {
|
func Test_App_AddCustomRequestMethod(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
methods := append(DefaultMethods, "TEST")
|
methods := append(DefaultMethods, "TEST") //nolint:gocritic // We want a new slice here
|
||||||
app := New(Config{
|
app := New(Config{
|
||||||
RequestMethods: methods,
|
RequestMethods: methods,
|
||||||
})
|
})
|
||||||
|
|
83
client.go
83
client.go
|
@ -15,6 +15,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -50,6 +51,7 @@ type Args = fasthttp.Args
|
||||||
// Copy from fasthttp
|
// Copy from fasthttp
|
||||||
type RetryIfFunc = fasthttp.RetryIfFunc
|
type RetryIfFunc = fasthttp.RetryIfFunc
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var defaultClient Client
|
var defaultClient Client
|
||||||
|
|
||||||
// Client implements http client.
|
// Client implements http client.
|
||||||
|
@ -186,11 +188,11 @@ func (a *Agent) Parse() error {
|
||||||
|
|
||||||
uri := a.req.URI()
|
uri := a.req.URI()
|
||||||
|
|
||||||
isTLS := false
|
var isTLS bool
|
||||||
scheme := uri.Scheme()
|
scheme := uri.Scheme()
|
||||||
if bytes.Equal(scheme, strHTTPS) {
|
if bytes.Equal(scheme, []byte(schemeHTTPS)) {
|
||||||
isTLS = true
|
isTLS = true
|
||||||
} else if !bytes.Equal(scheme, strHTTP) {
|
} else if !bytes.Equal(scheme, []byte(schemeHTTP)) {
|
||||||
return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
|
return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,7 +243,7 @@ func (a *Agent) SetBytesV(k string, v []byte) *Agent {
|
||||||
// SetBytesKV sets the given 'key: value' header.
|
// SetBytesKV sets the given 'key: value' header.
|
||||||
//
|
//
|
||||||
// Use AddBytesKV for setting multiple header values under the same key.
|
// Use AddBytesKV for setting multiple header values under the same key.
|
||||||
func (a *Agent) SetBytesKV(k []byte, v []byte) *Agent {
|
func (a *Agent) SetBytesKV(k, v []byte) *Agent {
|
||||||
a.req.Header.SetBytesKV(k, v)
|
a.req.Header.SetBytesKV(k, v)
|
||||||
|
|
||||||
return a
|
return a
|
||||||
|
@ -281,7 +283,7 @@ func (a *Agent) AddBytesV(k string, v []byte) *Agent {
|
||||||
//
|
//
|
||||||
// Multiple headers with the same key may be added with this function.
|
// Multiple headers with the same key may be added with this function.
|
||||||
// Use SetBytesKV for setting a single header for the given key.
|
// Use SetBytesKV for setting a single header for the given key.
|
||||||
func (a *Agent) AddBytesKV(k []byte, v []byte) *Agent {
|
func (a *Agent) AddBytesKV(k, v []byte) *Agent {
|
||||||
a.req.Header.AddBytesKV(k, v)
|
a.req.Header.AddBytesKV(k, v)
|
||||||
|
|
||||||
return a
|
return a
|
||||||
|
@ -652,10 +654,8 @@ func (a *Agent) Reuse() *Agent {
|
||||||
// certificate chain and host name.
|
// certificate chain and host name.
|
||||||
func (a *Agent) InsecureSkipVerify() *Agent {
|
func (a *Agent) InsecureSkipVerify() *Agent {
|
||||||
if a.HostClient.TLSConfig == nil {
|
if a.HostClient.TLSConfig == nil {
|
||||||
/* #nosec G402 */
|
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We explicitly let the user set insecure mode here
|
||||||
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402
|
|
||||||
} else {
|
} else {
|
||||||
/* #nosec G402 */
|
|
||||||
a.HostClient.TLSConfig.InsecureSkipVerify = true
|
a.HostClient.TLSConfig.InsecureSkipVerify = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -728,14 +728,14 @@ func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent {
|
||||||
// Bytes returns the status code, bytes body and errors of url.
|
// Bytes returns the status code, bytes body and errors of url.
|
||||||
//
|
//
|
||||||
// it's not safe to use Agent after calling [Agent.Bytes]
|
// it's not safe to use Agent after calling [Agent.Bytes]
|
||||||
func (a *Agent) Bytes() (code int, body []byte, errs []error) {
|
func (a *Agent) Bytes() (int, []byte, []error) {
|
||||||
defer a.release()
|
defer a.release()
|
||||||
return a.bytes()
|
return a.bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) bytes() (code int, body []byte, errs []error) {
|
func (a *Agent) bytes() (code int, body []byte, errs []error) { //nolint:nonamedreturns,revive // We want to overwrite the body in a deferred func. TODO: Check if we really need to do this. We eventually want to get rid of all named returns.
|
||||||
if errs = append(errs, a.errs...); len(errs) > 0 {
|
if errs = append(errs, a.errs...); len(errs) > 0 {
|
||||||
return
|
return code, body, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -760,7 +760,7 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
|
||||||
code = resp.StatusCode()
|
code = resp.StatusCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
body = append(a.dest, resp.Body()...)
|
body = append(a.dest, resp.Body()...) //nolint:gocritic // We want to append to the returned slice here
|
||||||
|
|
||||||
if nilResp {
|
if nilResp {
|
||||||
ReleaseResponse(resp)
|
ReleaseResponse(resp)
|
||||||
|
@ -770,25 +770,25 @@ func (a *Agent) bytes() (code int, body []byte, errs []error) {
|
||||||
if a.timeout > 0 {
|
if a.timeout > 0 {
|
||||||
if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil {
|
if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
return
|
return code, body, errs
|
||||||
}
|
}
|
||||||
} else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) {
|
} else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) {
|
||||||
if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil {
|
if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
return
|
return code, body, errs
|
||||||
}
|
}
|
||||||
} else if err := a.HostClient.Do(req, resp); err != nil {
|
} else if err := a.HostClient.Do(req, resp); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return code, body, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
func printDebugInfo(req *Request, resp *Response, w io.Writer) {
|
func printDebugInfo(req *Request, resp *Response, w io.Writer) {
|
||||||
msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr())
|
msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr())
|
||||||
_, _ = w.Write(utils.UnsafeBytes(msg))
|
_, _ = w.Write(utils.UnsafeBytes(msg)) //nolint:errcheck // This will never fail
|
||||||
_, _ = req.WriteTo(w)
|
_, _ = req.WriteTo(w) //nolint:errcheck // This will never fail
|
||||||
_, _ = resp.WriteTo(w)
|
_, _ = resp.WriteTo(w) //nolint:errcheck // This will never fail
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the status code, string body and errors of url.
|
// String returns the status code, string body and errors of url.
|
||||||
|
@ -797,6 +797,7 @@ func printDebugInfo(req *Request, resp *Response, w io.Writer) {
|
||||||
func (a *Agent) String() (int, string, []error) {
|
func (a *Agent) String() (int, string, []error) {
|
||||||
defer a.release()
|
defer a.release()
|
||||||
code, body, errs := a.bytes()
|
code, body, errs := a.bytes()
|
||||||
|
// TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it?
|
||||||
|
|
||||||
return code, utils.UnsafeString(body), errs
|
return code, utils.UnsafeString(body), errs
|
||||||
}
|
}
|
||||||
|
@ -805,12 +806,15 @@ func (a *Agent) String() (int, string, []error) {
|
||||||
// And bytes body will be unmarshalled to given v.
|
// And bytes body will be unmarshalled to given v.
|
||||||
//
|
//
|
||||||
// it's not safe to use Agent after calling [Agent.Struct]
|
// it's not safe to use Agent after calling [Agent.Struct]
|
||||||
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
|
func (a *Agent) Struct(v interface{}) (int, []byte, []error) {
|
||||||
defer a.release()
|
defer a.release()
|
||||||
if code, body, errs = a.bytes(); len(errs) > 0 {
|
|
||||||
return
|
code, body, errs := a.bytes()
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return code, body, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: This should only be done once
|
||||||
if a.jsonDecoder == nil {
|
if a.jsonDecoder == nil {
|
||||||
a.jsonDecoder = json.Unmarshal
|
a.jsonDecoder = json.Unmarshal
|
||||||
}
|
}
|
||||||
|
@ -819,7 +823,7 @@ func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return code, body, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) release() {
|
func (a *Agent) release() {
|
||||||
|
@ -855,6 +859,7 @@ func (a *Agent) reset() {
|
||||||
a.formFiles = a.formFiles[:0]
|
a.formFiles = a.formFiles[:0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use global vars here
|
||||||
var (
|
var (
|
||||||
clientPool sync.Pool
|
clientPool sync.Pool
|
||||||
agentPool = sync.Pool{
|
agentPool = sync.Pool{
|
||||||
|
@ -877,7 +882,11 @@ func AcquireClient() *Client {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
return &Client{}
|
return &Client{}
|
||||||
}
|
}
|
||||||
return v.(*Client)
|
c, ok := v.(*Client)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *Client"))
|
||||||
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseClient returns c acquired via AcquireClient to client pool.
|
// ReleaseClient returns c acquired via AcquireClient to client pool.
|
||||||
|
@ -899,7 +908,11 @@ func ReleaseClient(c *Client) {
|
||||||
// no longer needed. This allows Agent recycling, reduces GC pressure
|
// no longer needed. This allows Agent recycling, reduces GC pressure
|
||||||
// and usually improves performance.
|
// and usually improves performance.
|
||||||
func AcquireAgent() *Agent {
|
func AcquireAgent() *Agent {
|
||||||
return agentPool.Get().(*Agent)
|
a, ok := agentPool.Get().(*Agent)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *Agent"))
|
||||||
|
}
|
||||||
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseAgent returns a acquired via AcquireAgent to Agent pool.
|
// ReleaseAgent returns a acquired via AcquireAgent to Agent pool.
|
||||||
|
@ -922,7 +935,11 @@ func AcquireResponse() *Response {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
return &Response{}
|
return &Response{}
|
||||||
}
|
}
|
||||||
return v.(*Response)
|
r, ok := v.(*Response)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *Response"))
|
||||||
|
}
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseResponse return resp acquired via AcquireResponse to response pool.
|
// ReleaseResponse return resp acquired via AcquireResponse to response pool.
|
||||||
|
@ -945,7 +962,11 @@ func AcquireArgs() *Args {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
return &Args{}
|
return &Args{}
|
||||||
}
|
}
|
||||||
return v.(*Args)
|
a, ok := v.(*Args)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *Args"))
|
||||||
|
}
|
||||||
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
|
// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
|
||||||
|
@ -966,7 +987,11 @@ func AcquireFormFile() *FormFile {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
return &FormFile{}
|
return &FormFile{}
|
||||||
}
|
}
|
||||||
return v.(*FormFile)
|
ff, ok := v.(*FormFile)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *FormFile"))
|
||||||
|
}
|
||||||
|
return ff
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool.
|
// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool.
|
||||||
|
@ -981,9 +1006,7 @@ func ReleaseFormFile(ff *FormFile) {
|
||||||
formFilePool.Put(ff)
|
formFilePool.Put(ff)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
const (
|
||||||
strHTTP = []byte("http")
|
|
||||||
strHTTPS = []byte("https")
|
|
||||||
defaultUserAgent = "fiber"
|
defaultUserAgent = "fiber"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
//nolint:wrapcheck // We must not wrap errors in tests
|
||||||
package fiber
|
package fiber
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -19,6 +20,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/internal/tlstest"
|
"github.com/gofiber/fiber/v2/internal/tlstest"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp/fasthttputil"
|
"github.com/valyala/fasthttp/fasthttputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -295,8 +297,10 @@ func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) {
|
||||||
handler := func(c *Ctx) error {
|
handler := func(c *Ctx) error {
|
||||||
c.Request().Header.VisitAll(func(key, value []byte) {
|
c.Request().Header.VisitAll(func(key, value []byte) {
|
||||||
if k := string(key); k == "K1" || k == "K2" {
|
if k := string(key); k == "K1" || k == "K2" {
|
||||||
_, _ = c.Write(key)
|
_, err := c.Write(key)
|
||||||
_, _ = c.Write(value)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
_, err = c.Write(value)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
|
@ -581,25 +585,26 @@ type readErrorConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readErrorConn) Read(p []byte) (int, error) {
|
func (*readErrorConn) Read(_ []byte) (int, error) {
|
||||||
return 0, fmt.Errorf("error")
|
return 0, fmt.Errorf("error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readErrorConn) Write(p []byte) (int, error) {
|
func (*readErrorConn) Write(p []byte) (int, error) {
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readErrorConn) Close() error {
|
func (*readErrorConn) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readErrorConn) LocalAddr() net.Addr {
|
func (*readErrorConn) LocalAddr() net.Addr {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *readErrorConn) RemoteAddr() net.Addr {
|
func (*readErrorConn) RemoteAddr() net.Addr {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Client_Agent_RetryIf(t *testing.T) {
|
func Test_Client_Agent_RetryIf(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -783,7 +788,10 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) {
|
||||||
buf := make([]byte, fh1.Size)
|
buf := make([]byte, fh1.Size)
|
||||||
f, err := fh1.Open()
|
f, err := fh1.Open()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
defer func() { _ = f.Close() }()
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
_, err = f.Read(buf)
|
_, err = f.Read(buf)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "form file", string(buf))
|
utils.AssertEqual(t, "form file", string(buf))
|
||||||
|
@ -831,13 +839,16 @@ func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) {
|
||||||
basename := filepath.Base(filename)
|
basename := filepath.Base(filename)
|
||||||
utils.AssertEqual(t, fh.Filename, basename)
|
utils.AssertEqual(t, fh.Filename, basename)
|
||||||
|
|
||||||
b1, err := os.ReadFile(filename)
|
b1, err := os.ReadFile(filename) //nolint:gosec // We're in a test so reading user-provided files by name is fine
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
b2 := make([]byte, fh.Size)
|
b2 := make([]byte, fh.Size)
|
||||||
f, err := fh.Open()
|
f, err := fh.Open()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
defer func() { _ = f.Close() }()
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
_, err = f.Read(b2)
|
_, err = f.Read(b2)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, b1, b2)
|
utils.AssertEqual(t, b1, b2)
|
||||||
|
@ -962,6 +973,7 @@ func Test_Client_Agent_InsecureSkipVerify(t *testing.T) {
|
||||||
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
|
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
//nolint:gosec // We're in a test so using old ciphers is fine
|
||||||
serverTLSConf := &tls.Config{
|
serverTLSConf := &tls.Config{
|
||||||
Certificates: []tls.Certificate{cer},
|
Certificates: []tls.Certificate{cer},
|
||||||
}
|
}
|
||||||
|
@ -1137,7 +1149,7 @@ func Test_Client_Agent_Struct(t *testing.T) {
|
||||||
defer ReleaseAgent(a)
|
defer ReleaseAgent(a)
|
||||||
defer a.ConnectionClose()
|
defer a.ConnectionClose()
|
||||||
request := a.Request()
|
request := a.Request()
|
||||||
request.Header.SetMethod("GET")
|
request.Header.SetMethod(MethodGet)
|
||||||
request.SetRequestURI("http://example.com")
|
request.SetRequestURI("http://example.com")
|
||||||
err := a.Parse()
|
err := a.Parse()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -1198,8 +1210,8 @@ type errorMultipartWriter struct {
|
||||||
count int
|
count int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *errorMultipartWriter) Boundary() string { return "myBoundary" }
|
func (*errorMultipartWriter) Boundary() string { return "myBoundary" }
|
||||||
func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil }
|
func (*errorMultipartWriter) SetBoundary(_ string) error { return nil }
|
||||||
func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
|
func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
|
||||||
if e.count == 0 {
|
if e.count == 0 {
|
||||||
e.count++
|
e.count++
|
||||||
|
@ -1207,8 +1219,8 @@ func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
|
||||||
}
|
}
|
||||||
return errorWriter{}, nil
|
return errorWriter{}, nil
|
||||||
}
|
}
|
||||||
func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
|
func (*errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
|
||||||
func (e *errorMultipartWriter) Close() error { return errors.New("Close error") }
|
func (*errorMultipartWriter) Close() error { return errors.New("Close error") }
|
||||||
|
|
||||||
type errorWriter struct{}
|
type errorWriter struct{}
|
||||||
|
|
||||||
|
|
2
color.go
2
color.go
|
@ -53,6 +53,8 @@ type Colors struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultColors Default color codes
|
// DefaultColors Default color codes
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var DefaultColors = Colors{
|
var DefaultColors = Colors{
|
||||||
Black: "\u001b[90m",
|
Black: "\u001b[90m",
|
||||||
Red: "\u001b[91m",
|
Red: "\u001b[91m",
|
||||||
|
|
188
ctx.go
188
ctx.go
|
@ -27,10 +27,16 @@ import (
|
||||||
"github.com/gofiber/fiber/v2/internal/dictpool"
|
"github.com/gofiber/fiber/v2/internal/dictpool"
|
||||||
"github.com/gofiber/fiber/v2/internal/schema"
|
"github.com/gofiber/fiber/v2/internal/schema"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
schemeHTTP = "http"
|
||||||
|
schemeHTTPS = "https"
|
||||||
|
)
|
||||||
|
|
||||||
// maxParams defines the maximum number of parameters per route.
|
// maxParams defines the maximum number of parameters per route.
|
||||||
const maxParams = 30
|
const maxParams = 30
|
||||||
|
|
||||||
|
@ -45,6 +51,7 @@ const (
|
||||||
// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
|
// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
|
||||||
const userContextKey = "__local_user_context__"
|
const userContextKey = "__local_user_context__"
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use global vars here
|
||||||
var (
|
var (
|
||||||
// decoderPoolMap helps to improve BodyParser's, QueryParser's and ReqHeaderParser's performance
|
// decoderPoolMap helps to improve BodyParser's, QueryParser's and ReqHeaderParser's performance
|
||||||
decoderPoolMap = map[string]*sync.Pool{}
|
decoderPoolMap = map[string]*sync.Pool{}
|
||||||
|
@ -52,6 +59,7 @@ var (
|
||||||
tags = []string{queryTag, bodyTag, reqHeaderTag, paramsTag}
|
tags = []string{queryTag, bodyTag, reqHeaderTag, paramsTag}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:gochecknoinits // init() is used to initialize a global map variable
|
||||||
func init() {
|
func init() {
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
decoderPoolMap[tag] = &sync.Pool{New: func() interface{} {
|
decoderPoolMap[tag] = &sync.Pool{New: func() interface{} {
|
||||||
|
@ -100,9 +108,10 @@ type TLSHandler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientInfo Callback function to set CHI
|
// GetClientInfo Callback function to set CHI
|
||||||
|
// TODO: Why is this a getter which sets stuff?
|
||||||
func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
t.clientHelloInfo = info
|
t.clientHelloInfo = info
|
||||||
return nil, nil
|
return nil, nil //nolint:nilnil // Not returning anything useful here is probably fine
|
||||||
}
|
}
|
||||||
|
|
||||||
// Range data for c.Range
|
// Range data for c.Range
|
||||||
|
@ -151,7 +160,10 @@ type ParserConfig struct {
|
||||||
|
|
||||||
// AcquireCtx retrieves a new Ctx from the pool.
|
// AcquireCtx retrieves a new Ctx from the pool.
|
||||||
func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
|
func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
|
||||||
c := app.pool.Get().(*Ctx)
|
c, ok := app.pool.Get().(*Ctx)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *Ctx"))
|
||||||
|
}
|
||||||
// Set app reference
|
// Set app reference
|
||||||
c.app = app
|
c.app = app
|
||||||
// Reset route and handler index
|
// Reset route and handler index
|
||||||
|
@ -388,7 +400,6 @@ func (c *Ctx) BodyParser(out interface{}) error {
|
||||||
} else {
|
} else {
|
||||||
data[k] = append(data[k], v)
|
data[k] = append(data[k], v)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return c.parseToStruct(bodyTag, out, data)
|
return c.parseToStruct(bodyTag, out, data)
|
||||||
|
@ -401,7 +412,10 @@ func (c *Ctx) BodyParser(out interface{}) error {
|
||||||
return c.parseToStruct(bodyTag, out, data.Value)
|
return c.parseToStruct(bodyTag, out, data.Value)
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
|
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
|
||||||
return xml.Unmarshal(c.Body(), out)
|
if err := xml.Unmarshal(c.Body(), out); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
// No suitable content type found
|
// No suitable content type found
|
||||||
return ErrUnprocessableEntity
|
return ErrUnprocessableEntity
|
||||||
|
@ -673,8 +687,11 @@ func (c *Ctx) Hostname() string {
|
||||||
|
|
||||||
// Port returns the remote port of the request.
|
// Port returns the remote port of the request.
|
||||||
func (c *Ctx) Port() string {
|
func (c *Ctx) Port() string {
|
||||||
port := c.fasthttp.RemoteAddr().(*net.TCPAddr).Port
|
tcpaddr, ok := c.fasthttp.RemoteAddr().(*net.TCPAddr)
|
||||||
return strconv.Itoa(port)
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *net.TCPAddr"))
|
||||||
|
}
|
||||||
|
return strconv.Itoa(tcpaddr.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IP returns the remote IP address of the request.
|
// IP returns the remote IP address of the request.
|
||||||
|
@ -691,13 +708,16 @@ func (c *Ctx) IP() string {
|
||||||
// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
|
// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
|
||||||
// When IP validation is enabled, any invalid IPs will be omitted.
|
// When IP validation is enabled, any invalid IPs will be omitted.
|
||||||
func (c *Ctx) extractIPsFromHeader(header string) []string {
|
func (c *Ctx) extractIPsFromHeader(header string) []string {
|
||||||
|
// TODO: Reuse the c.extractIPFromHeader func somehow in here
|
||||||
|
|
||||||
headerValue := c.Get(header)
|
headerValue := c.Get(header)
|
||||||
|
|
||||||
// We can't know how many IPs we will return, but we will try to guess with this constant division.
|
// We can't know how many IPs we will return, but we will try to guess with this constant division.
|
||||||
// Counting ',' makes function slower for about 50ns in general case.
|
// Counting ',' makes function slower for about 50ns in general case.
|
||||||
estimatedCount := len(headerValue) / 8
|
const maxEstimatedCount = 8
|
||||||
if estimatedCount > 8 {
|
estimatedCount := len(headerValue) / maxEstimatedCount
|
||||||
estimatedCount = 8 // Avoid big allocation on big header
|
if estimatedCount > maxEstimatedCount {
|
||||||
|
estimatedCount = maxEstimatedCount // Avoid big allocation on big header
|
||||||
}
|
}
|
||||||
|
|
||||||
ipsFound := make([]string, 0, estimatedCount)
|
ipsFound := make([]string, 0, estimatedCount)
|
||||||
|
@ -707,11 +727,10 @@ func (c *Ctx) extractIPsFromHeader(header string) []string {
|
||||||
|
|
||||||
iploop:
|
iploop:
|
||||||
for {
|
for {
|
||||||
v4 := false
|
var v4, v6 bool
|
||||||
v6 := false
|
|
||||||
|
|
||||||
// Manually splitting string without allocating slice, working with parts directly
|
// Manually splitting string without allocating slice, working with parts directly
|
||||||
i, j = j+1, j+2
|
i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
|
||||||
|
|
||||||
if j > len(headerValue) {
|
if j > len(headerValue) {
|
||||||
break
|
break
|
||||||
|
@ -758,9 +777,10 @@ func (c *Ctx) extractIPFromHeader(header string) string {
|
||||||
|
|
||||||
iploop:
|
iploop:
|
||||||
for {
|
for {
|
||||||
v4 := false
|
var v4, v6 bool
|
||||||
v6 := false
|
|
||||||
i, j = j+1, j+2
|
// Manually splitting string without allocating slice, working with parts directly
|
||||||
|
i, j = j+1, j+2 //nolint:gomnd // Using these values is fine
|
||||||
|
|
||||||
if j > len(headerValue) {
|
if j > len(headerValue) {
|
||||||
break
|
break
|
||||||
|
@ -793,14 +813,14 @@ func (c *Ctx) extractIPFromHeader(header string) string {
|
||||||
return c.fasthttp.RemoteIP().String()
|
return c.fasthttp.RemoteIP().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// default behaviour if IP validation is not enabled is just to return whatever value is
|
// default behavior if IP validation is not enabled is just to return whatever value is
|
||||||
// in the proxy header. Even if it is empty or invalid
|
// in the proxy header. Even if it is empty or invalid
|
||||||
return c.Get(c.app.config.ProxyHeader)
|
return c.Get(c.app.config.ProxyHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header.
|
// IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header.
|
||||||
// When IP validation is enabled, only valid IPs are returned.
|
// When IP validation is enabled, only valid IPs are returned.
|
||||||
func (c *Ctx) IPs() (ips []string) {
|
func (c *Ctx) IPs() []string {
|
||||||
return c.extractIPsFromHeader(HeaderXForwardedFor)
|
return c.extractIPsFromHeader(HeaderXForwardedFor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -839,7 +859,7 @@ func (c *Ctx) JSON(data interface{}) error {
|
||||||
func (c *Ctx) JSONP(data interface{}, callback ...string) error {
|
func (c *Ctx) JSONP(data interface{}, callback ...string) error {
|
||||||
raw, err := json.Marshal(data)
|
raw, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to marshal: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result, cb string
|
var result, cb string
|
||||||
|
@ -877,11 +897,11 @@ func (c *Ctx) Links(link ...string) {
|
||||||
bb := bytebufferpool.Get()
|
bb := bytebufferpool.Get()
|
||||||
for i := range link {
|
for i := range link {
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_ = bb.WriteByte('<')
|
_ = bb.WriteByte('<') //nolint:errcheck // This will never fail
|
||||||
_, _ = bb.WriteString(link[i])
|
_, _ = bb.WriteString(link[i]) //nolint:errcheck // This will never fail
|
||||||
_ = bb.WriteByte('>')
|
_ = bb.WriteByte('>') //nolint:errcheck // This will never fail
|
||||||
} else {
|
} else {
|
||||||
_, _ = bb.WriteString(`; rel="` + link[i] + `",`)
|
_, _ = bb.WriteString(`; rel="` + link[i] + `",`) //nolint:errcheck // This will never fail
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.setCanonical(HeaderLink, utils.TrimRight(c.app.getString(bb.Bytes()), ','))
|
c.setCanonical(HeaderLink, utils.TrimRight(c.app.getString(bb.Bytes()), ','))
|
||||||
|
@ -890,7 +910,7 @@ func (c *Ctx) Links(link ...string) {
|
||||||
|
|
||||||
// Locals makes it possible to pass interface{} values under keys scoped to the request
|
// Locals makes it possible to pass interface{} values under keys scoped to the request
|
||||||
// and therefore available to all following routes that match the request.
|
// and therefore available to all following routes that match the request.
|
||||||
func (c *Ctx) Locals(key interface{}, value ...interface{}) (val interface{}) {
|
func (c *Ctx) Locals(key interface{}, value ...interface{}) interface{} {
|
||||||
if len(value) == 0 {
|
if len(value) == 0 {
|
||||||
return c.fasthttp.UserValue(key)
|
return c.fasthttp.UserValue(key)
|
||||||
}
|
}
|
||||||
|
@ -933,9 +953,10 @@ func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next executes the next method in the stack that matches the current route.
|
// Next executes the next method in the stack that matches the current route.
|
||||||
func (c *Ctx) Next() (err error) {
|
func (c *Ctx) Next() error {
|
||||||
// Increment handler index
|
// Increment handler index
|
||||||
c.indexHandler++
|
c.indexHandler++
|
||||||
|
var err error
|
||||||
// Did we executed all route handlers?
|
// Did we executed all route handlers?
|
||||||
if c.indexHandler < len(c.route.Handlers) {
|
if c.indexHandler < len(c.route.Handlers) {
|
||||||
// Continue route stack
|
// Continue route stack
|
||||||
|
@ -947,7 +968,7 @@ func (c *Ctx) Next() (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestartRouting instead of going to the next handler. This may be usefull after
|
// RestartRouting instead of going to the next handler. This may be useful after
|
||||||
// changing the request path. Note that handlers might be executed again.
|
// changing the request path. Note that handlers might be executed again.
|
||||||
func (c *Ctx) RestartRouting() error {
|
func (c *Ctx) RestartRouting() error {
|
||||||
c.indexRoute = -1
|
c.indexRoute = -1
|
||||||
|
@ -1017,9 +1038,8 @@ func (c *Ctx) ParamsInt(key string, defaultValue ...int) (int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(defaultValue) > 0 {
|
if len(defaultValue) > 0 {
|
||||||
return defaultValue[0], nil
|
return defaultValue[0], nil
|
||||||
} else {
|
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
|
return 0, fmt.Errorf("failed to convert: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return value, nil
|
return value, nil
|
||||||
|
@ -1044,15 +1064,16 @@ func (c *Ctx) Path(override ...string) string {
|
||||||
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
|
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
|
||||||
func (c *Ctx) Protocol() string {
|
func (c *Ctx) Protocol() string {
|
||||||
if c.fasthttp.IsTLS() {
|
if c.fasthttp.IsTLS() {
|
||||||
return "https"
|
return schemeHTTPS
|
||||||
}
|
}
|
||||||
if !c.IsProxyTrusted() {
|
if !c.IsProxyTrusted() {
|
||||||
return "http"
|
return schemeHTTP
|
||||||
}
|
}
|
||||||
|
|
||||||
scheme := "http"
|
scheme := schemeHTTP
|
||||||
|
const lenXHeaderName = 12
|
||||||
c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
|
c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
|
||||||
if len(key) < 12 {
|
if len(key) < lenXHeaderName {
|
||||||
return // Neither "X-Forwarded-" nor "X-Url-Scheme"
|
return // Neither "X-Forwarded-" nor "X-Url-Scheme"
|
||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
|
@ -1067,7 +1088,7 @@ func (c *Ctx) Protocol() string {
|
||||||
scheme = v
|
scheme = v
|
||||||
}
|
}
|
||||||
} else if bytes.Equal(key, []byte(HeaderXForwardedSsl)) && bytes.Equal(val, []byte("on")) {
|
} else if bytes.Equal(key, []byte(HeaderXForwardedSsl)) && bytes.Equal(val, []byte("on")) {
|
||||||
scheme = "https"
|
scheme = schemeHTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
case bytes.Equal(key, []byte(HeaderXUrlScheme)):
|
case bytes.Equal(key, []byte(HeaderXUrlScheme)):
|
||||||
|
@ -1100,9 +1121,8 @@ func (c *Ctx) QueryInt(key string, defaultValue ...int) int {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(defaultValue) > 0 {
|
if len(defaultValue) > 0 {
|
||||||
return defaultValue[0]
|
return defaultValue[0]
|
||||||
} else {
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
@ -1133,7 +1153,6 @@ func (c *Ctx) QueryParser(out interface{}) error {
|
||||||
} else {
|
} else {
|
||||||
data[k] = append(data[k], v)
|
data[k] = append(data[k], v)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1150,10 +1169,9 @@ func parseParamSquareBrackets(k string) (string, error) {
|
||||||
kbytes := []byte(k)
|
kbytes := []byte(k)
|
||||||
|
|
||||||
for i, b := range kbytes {
|
for i, b := range kbytes {
|
||||||
|
|
||||||
if b == '[' && kbytes[i+1] != ']' {
|
if b == '[' && kbytes[i+1] != ']' {
|
||||||
if err := bb.WriteByte('.'); err != nil {
|
if err := bb.WriteByte('.'); err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to write: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1162,7 +1180,7 @@ func parseParamSquareBrackets(k string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bb.WriteByte(b); err != nil {
|
if err := bb.WriteByte(b); err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to write: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1184,21 +1202,27 @@ func (c *Ctx) ReqHeaderParser(out interface{}) error {
|
||||||
} else {
|
} else {
|
||||||
data[k] = append(data[k], v)
|
data[k] = append(data[k], v)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return c.parseToStruct(reqHeaderTag, out, data)
|
return c.parseToStruct(reqHeaderTag, out, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error {
|
func (*Ctx) parseToStruct(aliasTag string, out interface{}, data map[string][]string) error {
|
||||||
// Get decoder from pool
|
// Get decoder from pool
|
||||||
schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder)
|
schemaDecoder, ok := decoderPoolMap[aliasTag].Get().(*schema.Decoder)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *schema.Decoder"))
|
||||||
|
}
|
||||||
defer decoderPoolMap[aliasTag].Put(schemaDecoder)
|
defer decoderPoolMap[aliasTag].Put(schemaDecoder)
|
||||||
|
|
||||||
// Set alias tag
|
// Set alias tag
|
||||||
schemaDecoder.SetAliasTag(aliasTag)
|
schemaDecoder.SetAliasTag(aliasTag)
|
||||||
|
|
||||||
return schemaDecoder.Decode(out, data)
|
if err := schemaDecoder.Decode(out, data); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func equalFieldType(out interface{}, kind reflect.Kind, key string) bool {
|
func equalFieldType(out interface{}, kind reflect.Kind, key string) bool {
|
||||||
|
@ -1248,24 +1272,23 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Range returns a struct containing the type and a slice of ranges.
|
// Range returns a struct containing the type and a slice of ranges.
|
||||||
func (c *Ctx) Range(size int) (rangeData Range, err error) {
|
func (c *Ctx) Range(size int) (Range, error) {
|
||||||
|
var rangeData Range
|
||||||
rangeStr := c.Get(HeaderRange)
|
rangeStr := c.Get(HeaderRange)
|
||||||
if rangeStr == "" || !strings.Contains(rangeStr, "=") {
|
if rangeStr == "" || !strings.Contains(rangeStr, "=") {
|
||||||
err = ErrRangeMalformed
|
return rangeData, ErrRangeMalformed
|
||||||
return
|
|
||||||
}
|
}
|
||||||
data := strings.Split(rangeStr, "=")
|
data := strings.Split(rangeStr, "=")
|
||||||
if len(data) != 2 {
|
const expectedDataParts = 2
|
||||||
err = ErrRangeMalformed
|
if len(data) != expectedDataParts {
|
||||||
return
|
return rangeData, ErrRangeMalformed
|
||||||
}
|
}
|
||||||
rangeData.Type = data[0]
|
rangeData.Type = data[0]
|
||||||
arr := strings.Split(data[1], ",")
|
arr := strings.Split(data[1], ",")
|
||||||
for i := 0; i < len(arr); i++ {
|
for i := 0; i < len(arr); i++ {
|
||||||
item := strings.Split(arr[i], "-")
|
item := strings.Split(arr[i], "-")
|
||||||
if len(item) == 1 {
|
if len(item) == 1 {
|
||||||
err = ErrRangeMalformed
|
return rangeData, ErrRangeMalformed
|
||||||
return
|
|
||||||
}
|
}
|
||||||
start, startErr := strconv.Atoi(item[0])
|
start, startErr := strconv.Atoi(item[0])
|
||||||
end, endErr := strconv.Atoi(item[1])
|
end, endErr := strconv.Atoi(item[1])
|
||||||
|
@ -1290,11 +1313,10 @@ func (c *Ctx) Range(size int) (rangeData Range, err error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if len(rangeData.Ranges) < 1 {
|
if len(rangeData.Ranges) < 1 {
|
||||||
err = ErrRangeUnsatisfiable
|
return rangeData, ErrRangeUnsatisfiable
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return rangeData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect to the URL derived from the specified path, with specified status.
|
// Redirect to the URL derived from the specified path, with specified status.
|
||||||
|
@ -1330,7 +1352,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
|
||||||
if !segment.IsParam {
|
if !segment.IsParam {
|
||||||
_, err := buf.WriteString(segment.Const)
|
_, err := buf.WriteString(segment.Const)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to write string: %w", err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -1341,7 +1363,7 @@ func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) {
|
||||||
if isSame || isGreedy {
|
if isSame || isGreedy {
|
||||||
_, err := buf.WriteString(utils.ToString(val))
|
_, err := buf.WriteString(utils.ToString(val))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to write string: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1373,10 +1395,10 @@ func (c *Ctx) RedirectToRoute(routeName string, params Map, status ...int) error
|
||||||
|
|
||||||
i := 1
|
i := 1
|
||||||
for k, v := range queries {
|
for k, v := range queries {
|
||||||
_, _ = queryText.WriteString(k + "=" + v)
|
_, _ = queryText.WriteString(k + "=" + v) //nolint:errcheck // This will never fail
|
||||||
|
|
||||||
if i != len(queries) {
|
if i != len(queries) {
|
||||||
_, _ = queryText.WriteString("&")
|
_, _ = queryText.WriteString("&") //nolint:errcheck // This will never fail
|
||||||
}
|
}
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
@ -1399,7 +1421,6 @@ func (c *Ctx) RedirectBack(fallback string, status ...int) error {
|
||||||
// Render a template with data and sends a text/html response.
|
// Render a template with data and sends a text/html response.
|
||||||
// We support the following engines: html, amber, handlebars, mustache, pug
|
// We support the following engines: html, amber, handlebars, mustache, pug
|
||||||
func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
||||||
var err error
|
|
||||||
// Get new buffer from pool
|
// Get new buffer from pool
|
||||||
buf := bytebufferpool.Get()
|
buf := bytebufferpool.Get()
|
||||||
defer bytebufferpool.Put(buf)
|
defer bytebufferpool.Put(buf)
|
||||||
|
@ -1421,7 +1442,7 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
||||||
// Render template from Views
|
// Render template from Views
|
||||||
if app.config.Views != nil {
|
if app.config.Views != nil {
|
||||||
if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil {
|
if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to render: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rendered = true
|
rendered = true
|
||||||
|
@ -1433,17 +1454,18 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
||||||
if !rendered {
|
if !rendered {
|
||||||
// Render raw template using 'name' as filepath if no engine is set
|
// Render raw template using 'name' as filepath if no engine is set
|
||||||
var tmpl *template.Template
|
var tmpl *template.Template
|
||||||
if _, err = readContent(buf, name); err != nil {
|
if _, err := readContent(buf, name); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Parse template
|
// Parse template
|
||||||
if tmpl, err = template.New("").Parse(c.app.getString(buf.Bytes())); err != nil {
|
tmpl, err := template.New("").Parse(c.app.getString(buf.Bytes()))
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse: %w", err)
|
||||||
}
|
}
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
// Render template
|
// Render template
|
||||||
if err = tmpl.Execute(buf, bind); err != nil {
|
if err := tmpl.Execute(buf, bind); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to execute: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1451,8 +1473,8 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error {
|
||||||
c.fasthttp.Response.Header.SetContentType(MIMETextHTMLCharsetUTF8)
|
c.fasthttp.Response.Header.SetContentType(MIMETextHTMLCharsetUTF8)
|
||||||
// Set rendered template to body
|
// Set rendered template to body
|
||||||
c.fasthttp.Response.SetBody(buf.Bytes())
|
c.fasthttp.Response.SetBody(buf.Bytes())
|
||||||
// Return err if exist
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Ctx) renderExtensions(bind interface{}) {
|
func (c *Ctx) renderExtensions(bind interface{}) {
|
||||||
|
@ -1501,28 +1523,32 @@ func (c *Ctx) Route() *Route {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveFile saves any multipart file to disk.
|
// SaveFile saves any multipart file to disk.
|
||||||
func (c *Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error {
|
func (*Ctx) SaveFile(fileheader *multipart.FileHeader, path string) error {
|
||||||
return fasthttp.SaveMultipartFile(fileheader, path)
|
return fasthttp.SaveMultipartFile(fileheader, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveFileToStorage saves any multipart file to an external storage system.
|
// SaveFileToStorage saves any multipart file to an external storage system.
|
||||||
func (c *Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
|
func (*Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
|
||||||
file, err := fileheader.Open()
|
file, err := fileheader.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to open: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
content, err := io.ReadAll(file)
|
content, err := io.ReadAll(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return storage.Set(path, content, 0)
|
if err := storage.Set(path, content, 0); err != nil {
|
||||||
|
return fmt.Errorf("failed to store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Secure returns whether a secure connection was established.
|
// Secure returns whether a secure connection was established.
|
||||||
func (c *Ctx) Secure() bool {
|
func (c *Ctx) Secure() bool {
|
||||||
return c.Protocol() == "https"
|
return c.Protocol() == schemeHTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sets the HTTP response body without copying it.
|
// Send sets the HTTP response body without copying it.
|
||||||
|
@ -1533,6 +1559,7 @@ func (c *Ctx) Send(body []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use global vars here
|
||||||
var (
|
var (
|
||||||
sendFileOnce sync.Once
|
sendFileOnce sync.Once
|
||||||
sendFileFS *fasthttp.FS
|
sendFileFS *fasthttp.FS
|
||||||
|
@ -1548,6 +1575,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
|
||||||
|
|
||||||
// https://github.com/valyala/fasthttp/blob/c7576cc10cabfc9c993317a2d3f8355497bea156/fs.go#L129-L134
|
// https://github.com/valyala/fasthttp/blob/c7576cc10cabfc9c993317a2d3f8355497bea156/fs.go#L129-L134
|
||||||
sendFileOnce.Do(func() {
|
sendFileOnce.Do(func() {
|
||||||
|
const cacheDuration = 10 * time.Second
|
||||||
sendFileFS = &fasthttp.FS{
|
sendFileFS = &fasthttp.FS{
|
||||||
Root: "",
|
Root: "",
|
||||||
AllowEmptyRoot: true,
|
AllowEmptyRoot: true,
|
||||||
|
@ -1555,7 +1583,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
|
||||||
AcceptByteRange: true,
|
AcceptByteRange: true,
|
||||||
Compress: true,
|
Compress: true,
|
||||||
CompressedFileSuffix: c.app.config.CompressedFileSuffix,
|
CompressedFileSuffix: c.app.config.CompressedFileSuffix,
|
||||||
CacheDuration: 10 * time.Second,
|
CacheDuration: cacheDuration,
|
||||||
IndexNames: []string{"index.html"},
|
IndexNames: []string{"index.html"},
|
||||||
PathNotFound: func(ctx *fasthttp.RequestCtx) {
|
PathNotFound: func(ctx *fasthttp.RequestCtx) {
|
||||||
ctx.Response.SetStatusCode(StatusNotFound)
|
ctx.Response.SetStatusCode(StatusNotFound)
|
||||||
|
@ -1579,7 +1607,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
|
||||||
var err error
|
var err error
|
||||||
file = filepath.FromSlash(file)
|
file = filepath.FromSlash(file)
|
||||||
if file, err = filepath.Abs(file); err != nil {
|
if file, err = filepath.Abs(file); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to determine abs file path: %w", err)
|
||||||
}
|
}
|
||||||
if hasTrailingSlash {
|
if hasTrailingSlash {
|
||||||
file += "/"
|
file += "/"
|
||||||
|
@ -1644,11 +1672,11 @@ func (c *Ctx) SendStream(stream io.Reader, size ...int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set sets the response's HTTP header field to the specified key, value.
|
// Set sets the response's HTTP header field to the specified key, value.
|
||||||
func (c *Ctx) Set(key string, val string) {
|
func (c *Ctx) Set(key, val string) {
|
||||||
c.fasthttp.Response.Header.Set(key, val)
|
c.fasthttp.Response.Header.Set(key, val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Ctx) setCanonical(key string, val string) {
|
func (c *Ctx) setCanonical(key, val string) {
|
||||||
c.fasthttp.Response.Header.SetCanonical(c.app.getBytes(key), c.app.getBytes(val))
|
c.fasthttp.Response.Header.SetCanonical(c.app.getBytes(key), c.app.getBytes(val))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1719,6 +1747,7 @@ func (c *Ctx) Write(p []byte) (int, error) {
|
||||||
|
|
||||||
// Writef appends f & a into response body writer.
|
// Writef appends f & a into response body writer.
|
||||||
func (c *Ctx) Writef(f string, a ...interface{}) (int, error) {
|
func (c *Ctx) Writef(f string, a ...interface{}) (int, error) {
|
||||||
|
//nolint:wrapcheck // This must not be wrapped
|
||||||
return fmt.Fprintf(c.fasthttp.Response.BodyWriter(), f, a...)
|
return fmt.Fprintf(c.fasthttp.Response.BodyWriter(), f, a...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1760,8 +1789,9 @@ func (c *Ctx) configDependentPaths() {
|
||||||
// Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed,
|
// Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed,
|
||||||
// since the first three characters area select a list of routes
|
// since the first three characters area select a list of routes
|
||||||
c.treePath = c.treePath[0:0]
|
c.treePath = c.treePath[0:0]
|
||||||
if len(c.detectionPath) >= 3 {
|
const maxDetectionPaths = 3
|
||||||
c.treePath = c.detectionPath[:3]
|
if len(c.detectionPath) >= maxDetectionPaths {
|
||||||
|
c.treePath = c.detectionPath[:maxDetectionPaths]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1786,7 +1816,7 @@ func (c *Ctx) IsProxyTrusted() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsLocalHost will return true if address is a localhost address.
|
// IsLocalHost will return true if address is a localhost address.
|
||||||
func (c *Ctx) isLocalHost(address string) bool {
|
func (*Ctx) isLocalHost(address string) bool {
|
||||||
localHosts := []string{"127.0.0.1", "0.0.0.0", "::1"}
|
localHosts := []string{"127.0.0.1", "0.0.0.0", "::1"}
|
||||||
for _, h := range localHosts {
|
for _, h := range localHosts {
|
||||||
if strings.Contains(address, h) {
|
if strings.Contains(address, h) {
|
||||||
|
|
310
ctx_test.go
310
ctx_test.go
|
@ -2,6 +2,7 @@
|
||||||
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
// 🤖 Github Repository: https://github.com/gofiber/fiber
|
||||||
// 📌 API Documentation: https://docs.gofiber.io
|
// 📌 API Documentation: https://docs.gofiber.io
|
||||||
|
|
||||||
|
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||||
package fiber
|
package fiber
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -28,6 +29,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
@ -361,8 +363,10 @@ func Test_Ctx_BodyParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
var gzipJSON bytes.Buffer
|
var gzipJSON bytes.Buffer
|
||||||
w := gzip.NewWriter(&gzipJSON)
|
w := gzip.NewWriter(&gzipJSON)
|
||||||
_, _ = w.Write([]byte(`{"name":"john"}`))
|
_, err := w.Write([]byte(`{"name":"john"}`))
|
||||||
_ = w.Close()
|
utils.AssertEqual(t, nil, err)
|
||||||
|
err = w.Close()
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
c.Request().Header.SetContentType(MIMEApplicationJSON)
|
c.Request().Header.SetContentType(MIMEApplicationJSON)
|
||||||
c.Request().Header.Set(HeaderContentEncoding, "gzip")
|
c.Request().Header.Set(HeaderContentEncoding, "gzip")
|
||||||
|
@ -431,9 +435,7 @@ func Test_Ctx_ParamParser(t *testing.T) {
|
||||||
UserID uint `params:"userId"`
|
UserID uint `params:"userId"`
|
||||||
RoleID uint `params:"roleId"`
|
RoleID uint `params:"roleId"`
|
||||||
}
|
}
|
||||||
var (
|
d := new(Demo)
|
||||||
d = new(Demo)
|
|
||||||
)
|
|
||||||
if err := ctx.ParamsParser(d); err != nil {
|
if err := ctx.ParamsParser(d); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -519,7 +521,7 @@ func Benchmark_Ctx_BodyParser_JSON(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
_ = c.BodyParser(d)
|
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||||
}
|
}
|
||||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||||
utils.AssertEqual(b, "john", d.Name)
|
utils.AssertEqual(b, "john", d.Name)
|
||||||
|
@ -543,7 +545,7 @@ func Benchmark_Ctx_BodyParser_XML(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
_ = c.BodyParser(d)
|
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||||
}
|
}
|
||||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||||
utils.AssertEqual(b, "john", d.Name)
|
utils.AssertEqual(b, "john", d.Name)
|
||||||
|
@ -567,7 +569,7 @@ func Benchmark_Ctx_BodyParser_Form(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
_ = c.BodyParser(d)
|
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||||
}
|
}
|
||||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||||
utils.AssertEqual(b, "john", d.Name)
|
utils.AssertEqual(b, "john", d.Name)
|
||||||
|
@ -592,7 +594,7 @@ func Benchmark_Ctx_BodyParser_MultipartForm(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
_ = c.BodyParser(d)
|
_ = c.BodyParser(d) //nolint:errcheck // It is fine to ignore the error here as we check it once further below
|
||||||
}
|
}
|
||||||
utils.AssertEqual(b, nil, c.BodyParser(d))
|
utils.AssertEqual(b, nil, c.BodyParser(d))
|
||||||
utils.AssertEqual(b, "john", d.Name)
|
utils.AssertEqual(b, "john", d.Name)
|
||||||
|
@ -879,12 +881,13 @@ func Test_Ctx_FormFile(t *testing.T) {
|
||||||
|
|
||||||
f, err := fh.Open()
|
f, err := fh.Open()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
defer func() {
|
||||||
|
utils.AssertEqual(t, nil, f.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
b := new(bytes.Buffer)
|
b := new(bytes.Buffer)
|
||||||
_, err = io.Copy(b, f)
|
_, err = io.Copy(b, f)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
f.Close()
|
|
||||||
utils.AssertEqual(t, "hello world", b.String())
|
utils.AssertEqual(t, "hello world", b.String())
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -897,8 +900,7 @@ func Test_Ctx_FormFile(t *testing.T) {
|
||||||
|
|
||||||
_, err = ioWriter.Write([]byte("hello world"))
|
_, err = ioWriter.Write([]byte("hello world"))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, nil, writer.Close())
|
||||||
writer.Close()
|
|
||||||
|
|
||||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||||
req.Header.Set(HeaderContentType, writer.FormDataContentType())
|
req.Header.Set(HeaderContentType, writer.FormDataContentType())
|
||||||
|
@ -921,10 +923,9 @@ func Test_Ctx_FormValue(t *testing.T) {
|
||||||
|
|
||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
writer := multipart.NewWriter(body)
|
writer := multipart.NewWriter(body)
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
|
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
|
||||||
|
utils.AssertEqual(t, nil, writer.Close())
|
||||||
|
|
||||||
writer.Close()
|
|
||||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||||
req.Header.Set("Content-Type", fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
|
req.Header.Set("Content-Type", fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
|
||||||
req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))
|
req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))
|
||||||
|
@ -1240,7 +1241,7 @@ func Test_Ctx_IP(t *testing.T) {
|
||||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
// default behaviour will return the remote IP from the stack
|
// default behavior will return the remote IP from the stack
|
||||||
utils.AssertEqual(t, "0.0.0.0", c.IP())
|
utils.AssertEqual(t, "0.0.0.0", c.IP())
|
||||||
|
|
||||||
// X-Forwarded-For is set, but it is ignored because proxyHeader is not set
|
// X-Forwarded-For is set, but it is ignored because proxyHeader is not set
|
||||||
|
@ -1252,7 +1253,7 @@ func Test_Ctx_IP(t *testing.T) {
|
||||||
func Test_Ctx_IP_ProxyHeader(t *testing.T) {
|
func Test_Ctx_IP_ProxyHeader(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// make sure that the same behaviour exists for different proxy header names
|
// make sure that the same behavior exists for different proxy header names
|
||||||
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
|
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
|
||||||
|
|
||||||
for _, proxyHeaderName := range proxyHeaderNames {
|
for _, proxyHeaderName := range proxyHeaderNames {
|
||||||
|
@ -1286,7 +1287,7 @@ func Test_Ctx_IP_ProxyHeader(t *testing.T) {
|
||||||
func Test_Ctx_IP_ProxyHeader_With_IP_Validation(t *testing.T) {
|
func Test_Ctx_IP_ProxyHeader_With_IP_Validation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// make sure that the same behaviour exists for different proxy header names
|
// make sure that the same behavior exists for different proxy header names
|
||||||
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
|
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}
|
||||||
|
|
||||||
for _, proxyHeaderName := range proxyHeaderNames {
|
for _, proxyHeaderName := range proxyHeaderNames {
|
||||||
|
@ -1625,35 +1626,43 @@ func Test_Ctx_ClientHelloInfo(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test without TLS handler
|
// Test without TLS handler
|
||||||
resp, _ := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
||||||
body, _ := io.ReadAll(resp.Body)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body)
|
utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body)
|
||||||
|
|
||||||
// Test with TLS Handler
|
// Test with TLS Handler
|
||||||
const (
|
const (
|
||||||
PSSWithSHA256 = 0x0804
|
pssWithSHA256 = 0x0804
|
||||||
VersionTLS13 = 0x0304
|
versionTLS13 = 0x0304
|
||||||
)
|
)
|
||||||
app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{
|
app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{
|
||||||
ServerName: "example.golang",
|
ServerName: "example.golang",
|
||||||
SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256},
|
SignatureSchemes: []tls.SignatureScheme{pssWithSHA256},
|
||||||
SupportedVersions: []uint16{VersionTLS13},
|
SupportedVersions: []uint16{versionTLS13},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
// Test ServerName
|
// Test ServerName
|
||||||
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
resp, err = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
|
||||||
body, _ = io.ReadAll(resp.Body)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
body, err = io.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, []byte("example.golang"), body)
|
utils.AssertEqual(t, []byte("example.golang"), body)
|
||||||
|
|
||||||
// Test SignatureSchemes
|
// Test SignatureSchemes
|
||||||
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
|
resp, err = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
|
||||||
body, _ = io.ReadAll(resp.Body)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "["+strconv.Itoa(PSSWithSHA256)+"]", string(body))
|
body, err = io.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, "["+strconv.Itoa(pssWithSHA256)+"]", string(body))
|
||||||
|
|
||||||
// Test SupportedVersions
|
// Test SupportedVersions
|
||||||
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
|
resp, err = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
|
||||||
body, _ = io.ReadAll(resp.Body)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "["+strconv.Itoa(VersionTLS13)+"]", string(body))
|
body, err = io.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, "["+strconv.Itoa(versionTLS13)+"]", string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_InvalidMethod
|
// go test -run Test_Ctx_InvalidMethod
|
||||||
|
@ -1688,10 +1697,9 @@ func Test_Ctx_MultipartForm(t *testing.T) {
|
||||||
|
|
||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
writer := multipart.NewWriter(body)
|
writer := multipart.NewWriter(body)
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
|
utils.AssertEqual(t, nil, writer.WriteField("name", "john"))
|
||||||
|
utils.AssertEqual(t, nil, writer.Close())
|
||||||
|
|
||||||
writer.Close()
|
|
||||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||||
req.Header.Set(HeaderContentType, fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
|
req.Header.Set(HeaderContentType, fmt.Sprintf("multipart/form-data; boundary=%s", writer.Boundary()))
|
||||||
req.Header.Set(HeaderContentLength, strconv.Itoa(len(body.Bytes())))
|
req.Header.Set(HeaderContentLength, strconv.Itoa(len(body.Bytes())))
|
||||||
|
@ -1706,8 +1714,8 @@ func Benchmark_Ctx_MultipartForm(b *testing.B) {
|
||||||
app := New()
|
app := New()
|
||||||
|
|
||||||
app.Post("/", func(c *Ctx) error {
|
app.Post("/", func(c *Ctx) error {
|
||||||
_, _ = c.MultipartForm()
|
_, err := c.MultipartForm()
|
||||||
return nil
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
c := &fasthttp.RequestCtx{}
|
c := &fasthttp.RequestCtx{}
|
||||||
|
@ -1889,11 +1897,16 @@ func Benchmark_Ctx_AllParams(b *testing.B) {
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
res = c.AllParams()
|
res = c.AllParams()
|
||||||
}
|
}
|
||||||
utils.AssertEqual(b, map[string]string{"param1": "john",
|
utils.AssertEqual(
|
||||||
|
b,
|
||||||
|
map[string]string{
|
||||||
|
"param1": "john",
|
||||||
"param2": "doe",
|
"param2": "doe",
|
||||||
"param3": "is",
|
"param3": "is",
|
||||||
"param4": "awesome"},
|
"param4": "awesome",
|
||||||
res)
|
},
|
||||||
|
res,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -v -run=^$ -bench=Benchmark_Ctx_ParamsParse -benchmem -count=4
|
// go test -v -run=^$ -bench=Benchmark_Ctx_ParamsParse -benchmem -count=4
|
||||||
|
@ -1964,31 +1977,31 @@ func Test_Ctx_Protocol(t *testing.T) {
|
||||||
|
|
||||||
c := app.AcquireCtx(freq)
|
c := app.AcquireCtx(freq)
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProto, "https, http")
|
c.Request().Header.Set(HeaderXForwardedProto, "https, http")
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https, http")
|
c.Request().Header.Set(HeaderXForwardedProtocol, "https, http")
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -v -run=^$ -bench=Benchmark_Ctx_Protocol -benchmem -count=4
|
// go test -v -run=^$ -bench=Benchmark_Ctx_Protocol -benchmem -count=4
|
||||||
|
@ -2002,7 +2015,7 @@ func Benchmark_Ctx_Protocol(b *testing.B) {
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
res = c.Protocol()
|
res = c.Protocol()
|
||||||
}
|
}
|
||||||
utils.AssertEqual(b, "http", res)
|
utils.AssertEqual(b, schemeHTTP, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_Protocol_TrustedProxy
|
// go test -run Test_Ctx_Protocol_TrustedProxy
|
||||||
|
@ -2012,23 +2025,23 @@ func Test_Ctx_Protocol_TrustedProxy(t *testing.T) {
|
||||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_Protocol_TrustedProxyRange
|
// go test -run Test_Ctx_Protocol_TrustedProxyRange
|
||||||
|
@ -2038,23 +2051,23 @@ func Test_Ctx_Protocol_TrustedProxyRange(t *testing.T) {
|
||||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "https", c.Protocol())
|
utils.AssertEqual(t, schemeHTTPS, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_Protocol_UntrustedProxyRange
|
// go test -run Test_Ctx_Protocol_UntrustedProxyRange
|
||||||
|
@ -2064,23 +2077,23 @@ func Test_Ctx_Protocol_UntrustedProxyRange(t *testing.T) {
|
||||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_Protocol_UnTrustedProxy
|
// go test -run Test_Ctx_Protocol_UnTrustedProxy
|
||||||
|
@ -2090,23 +2103,23 @@ func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) {
|
||||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProto, "https")
|
c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedProtocol, "https")
|
c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
c.Request().Header.Set(HeaderXForwardedSsl, "on")
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
c.Request().Header.Set(HeaderXUrlScheme, "https")
|
c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
c.Request().Header.Reset()
|
c.Request().Header.Reset()
|
||||||
|
|
||||||
utils.AssertEqual(t, "http", c.Protocol())
|
utils.AssertEqual(t, schemeHTTP, c.Protocol())
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_Query
|
// go test -run Test_Ctx_Query
|
||||||
|
@ -2224,7 +2237,12 @@ func Test_Ctx_SaveFile(t *testing.T) {
|
||||||
tempFile, err := os.CreateTemp(os.TempDir(), "test-")
|
tempFile, err := os.CreateTemp(os.TempDir(), "test-")
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
defer os.Remove(tempFile.Name())
|
defer func(file *os.File) {
|
||||||
|
err := file.Close()
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
err = os.Remove(file.Name())
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}(tempFile)
|
||||||
err = c.SaveFile(fh, tempFile.Name())
|
err = c.SaveFile(fh, tempFile.Name())
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -2242,7 +2260,7 @@ func Test_Ctx_SaveFile(t *testing.T) {
|
||||||
|
|
||||||
_, err = ioWriter.Write([]byte("hello world"))
|
_, err = ioWriter.Write([]byte("hello world"))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
writer.Close()
|
utils.AssertEqual(t, nil, writer.Close())
|
||||||
|
|
||||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
@ -2284,7 +2302,7 @@ func Test_Ctx_SaveFileToStorage(t *testing.T) {
|
||||||
|
|
||||||
_, err = ioWriter.Write([]byte("hello world"))
|
_, err = ioWriter.Write([]byte("hello world"))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
writer.Close()
|
utils.AssertEqual(t, nil, writer.Close())
|
||||||
|
|
||||||
req := httptest.NewRequest(MethodPost, "/test", body)
|
req := httptest.NewRequest(MethodPost, "/test", body)
|
||||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
@ -2370,7 +2388,9 @@ func Test_Ctx_Download(t *testing.T) {
|
||||||
|
|
||||||
f, err := os.Open("./ctx.go")
|
f, err := os.Open("./ctx.go")
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
defer f.Close()
|
defer func() {
|
||||||
|
utils.AssertEqual(t, nil, f.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
expect, err := io.ReadAll(f)
|
expect, err := io.ReadAll(f)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -2389,7 +2409,9 @@ func Test_Ctx_SendFile(t *testing.T) {
|
||||||
// fetch file content
|
// fetch file content
|
||||||
f, err := os.Open("./ctx.go")
|
f, err := os.Open("./ctx.go")
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
defer f.Close()
|
defer func() {
|
||||||
|
utils.AssertEqual(t, nil, f.Close())
|
||||||
|
}()
|
||||||
expectFileContent, err := io.ReadAll(f)
|
expectFileContent, err := io.ReadAll(f)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
// fetch file info for the not modified test case
|
// fetch file info for the not modified test case
|
||||||
|
@ -2435,7 +2457,7 @@ func Test_Ctx_SendFile_404(t *testing.T) {
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -2473,11 +2495,11 @@ func Test_Ctx_SendFile_Immutable(t *testing.T) {
|
||||||
for _, endpoint := range endpointsForTest {
|
for _, endpoint := range endpointsForTest {
|
||||||
t.Run(endpoint, func(t *testing.T) {
|
t.Run(endpoint, func(t *testing.T) {
|
||||||
// 1st try
|
// 1st try
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", endpoint, nil))
|
resp, err := app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, StatusOK, resp.StatusCode)
|
||||||
// 2nd try
|
// 2nd try
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", endpoint, nil))
|
resp, err = app.Test(httptest.NewRequest(MethodGet, endpoint, nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, StatusOK, resp.StatusCode)
|
||||||
})
|
})
|
||||||
|
@ -2495,9 +2517,9 @@ func Test_Ctx_SendFile_RestoreOriginalURL(t *testing.T) {
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err1 := app.Test(httptest.NewRequest("GET", "/?test=true", nil))
|
_, err1 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
|
||||||
// second request required to confirm with zero allocation
|
// second request required to confirm with zero allocation
|
||||||
_, err2 := app.Test(httptest.NewRequest("GET", "/?test=true", nil))
|
_, err2 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", nil))
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err1)
|
utils.AssertEqual(t, nil, err1)
|
||||||
utils.AssertEqual(t, nil, err2)
|
utils.AssertEqual(t, nil, err2)
|
||||||
|
@ -2893,12 +2915,12 @@ func Test_Ctx_Render(t *testing.T) {
|
||||||
err := c.Render("./.github/testdata/index.tmpl", Map{
|
err := c.Render("./.github/testdata/index.tmpl", Map{
|
||||||
"Title": "Hello, World!",
|
"Title": "Hello, World!",
|
||||||
})
|
})
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
buf := bytebufferpool.Get()
|
buf := bytebufferpool.Get()
|
||||||
_, _ = buf.WriteString("overwrite")
|
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||||
defer bytebufferpool.Put(buf)
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
|
||||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||||
|
|
||||||
err = c.Render("./.github/testdata/template-non-exists.html", nil)
|
err = c.Render("./.github/testdata/template-non-exists.html", nil)
|
||||||
|
@ -2918,12 +2940,12 @@ func Test_Ctx_RenderWithoutLocals(t *testing.T) {
|
||||||
c.Locals("Title", "Hello, World!")
|
c.Locals("Title", "Hello, World!")
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
err := c.Render("./.github/testdata/index.tmpl", Map{})
|
err := c.Render("./.github/testdata/index.tmpl", Map{})
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
buf := bytebufferpool.Get()
|
buf := bytebufferpool.Get()
|
||||||
_, _ = buf.WriteString("overwrite")
|
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||||
defer bytebufferpool.Put(buf)
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
|
||||||
utils.AssertEqual(t, "<h1><no value></h1>", string(c.Response().Body()))
|
utils.AssertEqual(t, "<h1><no value></h1>", string(c.Response().Body()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2937,14 +2959,13 @@ func Test_Ctx_RenderWithLocals(t *testing.T) {
|
||||||
c.Locals("Title", "Hello, World!")
|
c.Locals("Title", "Hello, World!")
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
err := c.Render("./.github/testdata/index.tmpl", Map{})
|
err := c.Render("./.github/testdata/index.tmpl", Map{})
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
buf := bytebufferpool.Get()
|
buf := bytebufferpool.Get()
|
||||||
_, _ = buf.WriteString("overwrite")
|
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||||
defer bytebufferpool.Put(buf)
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
|
||||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Ctx_RenderWithBind(t *testing.T) {
|
func Test_Ctx_RenderWithBind(t *testing.T) {
|
||||||
|
@ -2959,14 +2980,13 @@ func Test_Ctx_RenderWithBind(t *testing.T) {
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
err = c.Render("./.github/testdata/index.tmpl", Map{})
|
err = c.Render("./.github/testdata/index.tmpl", Map{})
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
buf := bytebufferpool.Get()
|
buf := bytebufferpool.Get()
|
||||||
_, _ = buf.WriteString("overwrite")
|
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||||
defer bytebufferpool.Put(buf)
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
|
||||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
|
func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
|
||||||
|
@ -2982,12 +3002,12 @@ func Test_Ctx_RenderWithOverwrittenBind(t *testing.T) {
|
||||||
err = c.Render("./.github/testdata/index.tmpl", Map{
|
err = c.Render("./.github/testdata/index.tmpl", Map{
|
||||||
"Title": "Hello from Fiber!",
|
"Title": "Hello from Fiber!",
|
||||||
})
|
})
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
buf := bytebufferpool.Get()
|
buf := bytebufferpool.Get()
|
||||||
_, _ = buf.WriteString("overwrite")
|
_, _ = buf.WriteString("overwrite") //nolint:errcheck // This will never fail
|
||||||
defer bytebufferpool.Put(buf)
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
|
||||||
utils.AssertEqual(t, "<h1>Hello from Fiber!</h1>", string(c.Response().Body()))
|
utils.AssertEqual(t, "<h1>Hello from Fiber!</h1>", string(c.Response().Body()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3005,13 +3025,12 @@ func Test_Ctx_RenderWithBindLocals(t *testing.T) {
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
c.Locals("Summary", "Test")
|
c.Locals("Summary", "Test")
|
||||||
|
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
err = c.Render("./.github/testdata/template.tmpl", Map{})
|
err = c.Render("./.github/testdata/template.tmpl", Map{})
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
|
|
||||||
|
|
||||||
|
utils.AssertEqual(t, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
|
func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
|
||||||
|
@ -3027,11 +3046,12 @@ func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) {
|
||||||
|
|
||||||
c.Locals("Title", "This is a test.")
|
c.Locals("Title", "This is a test.")
|
||||||
defer app.ReleaseCtx(c)
|
defer app.ReleaseCtx(c)
|
||||||
|
|
||||||
err = c.Render("index.tmpl", Map{
|
err = c.Render("index.tmpl", Map{
|
||||||
"Title": "Hello, World!",
|
"Title": "Hello, World!",
|
||||||
})
|
})
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
utils.AssertEqual(t, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3060,8 +3080,8 @@ func Benchmark_Ctx_RenderWithLocalsAndBinding(b *testing.B) {
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
err = c.Render("template.tmpl", Map{})
|
err = c.Render("template.tmpl", Map{})
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.AssertEqual(b, nil, err)
|
utils.AssertEqual(b, nil, err)
|
||||||
|
|
||||||
utils.AssertEqual(b, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
|
utils.AssertEqual(b, "<h1>Hello, World! Test</h1>", string(c.Response().Body()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3083,8 +3103,8 @@ func Benchmark_Ctx_RedirectToRoute(b *testing.B) {
|
||||||
"name": "fiber",
|
"name": "fiber",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.AssertEqual(b, nil, err)
|
utils.AssertEqual(b, nil, err)
|
||||||
|
|
||||||
utils.AssertEqual(b, 302, c.Response().StatusCode())
|
utils.AssertEqual(b, 302, c.Response().StatusCode())
|
||||||
utils.AssertEqual(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation)))
|
utils.AssertEqual(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation)))
|
||||||
}
|
}
|
||||||
|
@ -3108,8 +3128,8 @@ func Benchmark_Ctx_RedirectToRouteWithQueries(b *testing.B) {
|
||||||
"queries": map[string]string{"a": "a", "b": "b"},
|
"queries": map[string]string{"a": "a", "b": "b"},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.AssertEqual(b, nil, err)
|
utils.AssertEqual(b, nil, err)
|
||||||
|
|
||||||
utils.AssertEqual(b, 302, c.Response().StatusCode())
|
utils.AssertEqual(b, 302, c.Response().StatusCode())
|
||||||
// analysis of query parameters with url parsing, since a map pass is always randomly ordered
|
// analysis of query parameters with url parsing, since a map pass is always randomly ordered
|
||||||
location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation)))
|
location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation)))
|
||||||
|
@ -3138,8 +3158,8 @@ func Benchmark_Ctx_RenderLocals(b *testing.B) {
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
err = c.Render("index.tmpl", Map{})
|
err = c.Render("index.tmpl", Map{})
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.AssertEqual(b, nil, err)
|
utils.AssertEqual(b, nil, err)
|
||||||
|
|
||||||
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3164,8 +3184,8 @@ func Benchmark_Ctx_RenderBind(b *testing.B) {
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
err = c.Render("index.tmpl", Map{})
|
err = c.Render("index.tmpl", Map{})
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.AssertEqual(b, nil, err)
|
utils.AssertEqual(b, nil, err)
|
||||||
|
|
||||||
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
utils.AssertEqual(b, "<h1>Hello, World!</h1>", string(c.Response().Body()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3191,8 +3211,7 @@ func Test_Ctx_RestartRouting(t *testing.T) {
|
||||||
func Test_Ctx_RestartRoutingWithChangedPath(t *testing.T) {
|
func Test_Ctx_RestartRoutingWithChangedPath(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
app := New()
|
app := New()
|
||||||
executedOldHandler := false
|
var executedOldHandler, executedNewHandler bool
|
||||||
executedNewHandler := false
|
|
||||||
|
|
||||||
app.Get("/old", func(c *Ctx) error {
|
app.Get("/old", func(c *Ctx) error {
|
||||||
c.Path("/new")
|
c.Path("/new")
|
||||||
|
@ -3242,10 +3261,18 @@ type testTemplateEngine struct {
|
||||||
|
|
||||||
func (t *testTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
|
func (t *testTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
|
||||||
if len(layout) == 0 {
|
if len(layout) == 0 {
|
||||||
return t.templates.ExecuteTemplate(w, name, bind)
|
if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute template without layout: %w", err)
|
||||||
}
|
}
|
||||||
_ = t.templates.ExecuteTemplate(w, name, bind)
|
return nil
|
||||||
return t.templates.ExecuteTemplate(w, layout[0], bind)
|
}
|
||||||
|
if err := t.templates.ExecuteTemplate(w, name, bind); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute template: %w", err)
|
||||||
|
}
|
||||||
|
if err := t.templates.ExecuteTemplate(w, layout[0], bind); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute template with layout: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testTemplateEngine) Load() error {
|
func (t *testTemplateEngine) Load() error {
|
||||||
|
@ -3324,7 +3351,6 @@ func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) {
|
||||||
}
|
}
|
||||||
utils.AssertEqual(b, "/user/fiber", location)
|
utils.AssertEqual(b, "/user/fiber", location)
|
||||||
utils.AssertEqual(b, nil, err)
|
utils.AssertEqual(b, nil, err)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_Get_Location_From_Route_name
|
// go test -run Test_Ctx_Get_Location_From_Route_name
|
||||||
|
@ -3407,11 +3433,11 @@ func Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param(t *testing.
|
||||||
|
|
||||||
type errorTemplateEngine struct{}
|
type errorTemplateEngine struct{}
|
||||||
|
|
||||||
func (t errorTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error {
|
func (errorTemplateEngine) Render(_ io.Writer, _ string, _ interface{}, _ ...string) error {
|
||||||
return errors.New("errorTemplateEngine")
|
return errors.New("errorTemplateEngine")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t errorTemplateEngine) Load() error { return nil }
|
func (errorTemplateEngine) Load() error { return nil }
|
||||||
|
|
||||||
// go test -run Test_Ctx_Render_Engine_Error
|
// go test -run Test_Ctx_Render_Engine_Error
|
||||||
func Test_Ctx_Render_Engine_Error(t *testing.T) {
|
func Test_Ctx_Render_Engine_Error(t *testing.T) {
|
||||||
|
@ -3429,7 +3455,10 @@ func Test_Ctx_Render_Go_Template(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
file, err := os.CreateTemp(os.TempDir(), "fiber")
|
file, err := os.CreateTemp(os.TempDir(), "fiber")
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
defer os.Remove(file.Name())
|
defer func() {
|
||||||
|
err := os.Remove(file.Name())
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
|
|
||||||
_, err = file.Write([]byte("template"))
|
_, err = file.Write([]byte("template"))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -3821,7 +3850,7 @@ func Test_Ctx_QueryParser(t *testing.T) {
|
||||||
}
|
}
|
||||||
rq := new(RequiredQuery)
|
rq := new(RequiredQuery)
|
||||||
c.Request().URI().SetQueryString("")
|
c.Request().URI().SetQueryString("")
|
||||||
utils.AssertEqual(t, "name is empty", c.QueryParser(rq).Error())
|
utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(rq).Error())
|
||||||
|
|
||||||
type ArrayQuery struct {
|
type ArrayQuery struct {
|
||||||
Data []string
|
Data []string
|
||||||
|
@ -3837,7 +3866,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
type NonRFCTime time.Time
|
type NonRFCTime time.Time
|
||||||
|
|
||||||
NonRFCConverter := func(value string) reflect.Value {
|
nonRFCConverter := func(value string) reflect.Value {
|
||||||
if v, err := time.Parse("2006-01-02", value); err == nil {
|
if v, err := time.Parse("2006-01-02", value); err == nil {
|
||||||
return reflect.ValueOf(v)
|
return reflect.ValueOf(v)
|
||||||
}
|
}
|
||||||
|
@ -3846,7 +3875,7 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
|
||||||
|
|
||||||
nonRFCTime := ParserType{
|
nonRFCTime := ParserType{
|
||||||
Customtype: NonRFCTime{},
|
Customtype: NonRFCTime{},
|
||||||
Converter: NonRFCConverter,
|
Converter: nonRFCConverter,
|
||||||
}
|
}
|
||||||
|
|
||||||
SetParserDecoder(ParserConfig{
|
SetParserDecoder(ParserConfig{
|
||||||
|
@ -3872,7 +3901,6 @@ func Test_Ctx_QueryParser_WithSetParserDecoder(t *testing.T) {
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("date=2021-04-10&title=CustomDateTest&Body=October")
|
c.Request().URI().SetQueryString("date=2021-04-10&title=CustomDateTest&Body=October")
|
||||||
utils.AssertEqual(t, nil, c.QueryParser(q))
|
utils.AssertEqual(t, nil, c.QueryParser(q))
|
||||||
fmt.Println(q.Date, "q.Date")
|
|
||||||
utils.AssertEqual(t, "CustomDateTest", q.Title)
|
utils.AssertEqual(t, "CustomDateTest", q.Title)
|
||||||
date := fmt.Sprintf("%v", q.Date)
|
date := fmt.Sprintf("%v", q.Date)
|
||||||
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
|
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
|
||||||
|
@ -3907,7 +3935,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("namex=tom&nested.age=10")
|
c.Request().URI().SetQueryString("namex=tom&nested.age=10")
|
||||||
q = new(Query1)
|
q = new(Query1)
|
||||||
utils.AssertEqual(t, "name is empty", c.QueryParser(q).Error())
|
utils.AssertEqual(t, "failed to decode: name is empty", c.QueryParser(q).Error())
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("name=tom&nested.agex=10")
|
c.Request().URI().SetQueryString("name=tom&nested.agex=10")
|
||||||
q = new(Query1)
|
q = new(Query1)
|
||||||
|
@ -3915,7 +3943,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("name=tom&test.age=10")
|
c.Request().URI().SetQueryString("name=tom&test.age=10")
|
||||||
q = new(Query1)
|
q = new(Query1)
|
||||||
utils.AssertEqual(t, "nested is empty", c.QueryParser(q).Error())
|
utils.AssertEqual(t, "failed to decode: nested is empty", c.QueryParser(q).Error())
|
||||||
|
|
||||||
type Query2 struct {
|
type Query2 struct {
|
||||||
Name string `query:"name"`
|
Name string `query:"name"`
|
||||||
|
@ -3933,11 +3961,11 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("nested.agex=10")
|
c.Request().URI().SetQueryString("nested.agex=10")
|
||||||
q2 = new(Query2)
|
q2 = new(Query2)
|
||||||
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error())
|
utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("nested.agex=10")
|
c.Request().URI().SetQueryString("nested.agex=10")
|
||||||
q2 = new(Query2)
|
q2 = new(Query2)
|
||||||
utils.AssertEqual(t, "nested.age is empty", c.QueryParser(q2).Error())
|
utils.AssertEqual(t, "failed to decode: nested.age is empty", c.QueryParser(q2).Error())
|
||||||
|
|
||||||
type Node struct {
|
type Node struct {
|
||||||
Value int `query:"val,required"`
|
Value int `query:"val,required"`
|
||||||
|
@ -3951,7 +3979,7 @@ func Test_Ctx_QueryParser_Schema(t *testing.T) {
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("next.val=2")
|
c.Request().URI().SetQueryString("next.val=2")
|
||||||
n = new(Node)
|
n = new(Node)
|
||||||
utils.AssertEqual(t, "val is empty", c.QueryParser(n).Error())
|
utils.AssertEqual(t, "failed to decode: val is empty", c.QueryParser(n).Error())
|
||||||
|
|
||||||
c.Request().URI().SetQueryString("val=3&next.value=2")
|
c.Request().URI().SetQueryString("val=3&next.value=2")
|
||||||
n = new(Node)
|
n = new(Node)
|
||||||
|
@ -4057,7 +4085,7 @@ func Test_Ctx_ReqHeaderParser(t *testing.T) {
|
||||||
}
|
}
|
||||||
rh := new(RequiredHeader)
|
rh := new(RequiredHeader)
|
||||||
c.Request().Header.Del("name")
|
c.Request().Header.Del("name")
|
||||||
utils.AssertEqual(t, "name is empty", c.ReqHeaderParser(rh).Error())
|
utils.AssertEqual(t, "failed to decode: name is empty", c.ReqHeaderParser(rh).Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Ctx_ReqHeaderParser_WithSetParserDecoder -v
|
// go test -run Test_Ctx_ReqHeaderParser_WithSetParserDecoder -v
|
||||||
|
@ -4065,7 +4093,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
type NonRFCTime time.Time
|
type NonRFCTime time.Time
|
||||||
|
|
||||||
NonRFCConverter := func(value string) reflect.Value {
|
nonRFCConverter := func(value string) reflect.Value {
|
||||||
if v, err := time.Parse("2006-01-02", value); err == nil {
|
if v, err := time.Parse("2006-01-02", value); err == nil {
|
||||||
return reflect.ValueOf(v)
|
return reflect.ValueOf(v)
|
||||||
}
|
}
|
||||||
|
@ -4074,7 +4102,7 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
|
||||||
|
|
||||||
nonRFCTime := ParserType{
|
nonRFCTime := ParserType{
|
||||||
Customtype: NonRFCTime{},
|
Customtype: NonRFCTime{},
|
||||||
Converter: NonRFCConverter,
|
Converter: nonRFCConverter,
|
||||||
}
|
}
|
||||||
|
|
||||||
SetParserDecoder(ParserConfig{
|
SetParserDecoder(ParserConfig{
|
||||||
|
@ -4103,7 +4131,6 @@ func Test_Ctx_ReqHeaderParser_WithSetParserDecoder(t *testing.T) {
|
||||||
c.Request().Header.Add("Body", "October")
|
c.Request().Header.Add("Body", "October")
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, c.ReqHeaderParser(r))
|
utils.AssertEqual(t, nil, c.ReqHeaderParser(r))
|
||||||
fmt.Println(r.Date, "q.Date")
|
|
||||||
utils.AssertEqual(t, "CustomDateTest", r.Title)
|
utils.AssertEqual(t, "CustomDateTest", r.Title)
|
||||||
date := fmt.Sprintf("%v", r.Date)
|
date := fmt.Sprintf("%v", r.Date)
|
||||||
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
|
utils.AssertEqual(t, "{0 63753609600 <nil>}", date)
|
||||||
|
@ -4140,7 +4167,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
||||||
|
|
||||||
c.Request().Header.Del("Name")
|
c.Request().Header.Del("Name")
|
||||||
q = new(Header1)
|
q = new(Header1)
|
||||||
utils.AssertEqual(t, "Name is empty", c.ReqHeaderParser(q).Error())
|
utils.AssertEqual(t, "failed to decode: Name is empty", c.ReqHeaderParser(q).Error())
|
||||||
|
|
||||||
c.Request().Header.Add("Name", "tom")
|
c.Request().Header.Add("Name", "tom")
|
||||||
c.Request().Header.Del("Nested.Age")
|
c.Request().Header.Del("Nested.Age")
|
||||||
|
@ -4150,7 +4177,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
||||||
|
|
||||||
c.Request().Header.Del("Nested.Agex")
|
c.Request().Header.Del("Nested.Agex")
|
||||||
q = new(Header1)
|
q = new(Header1)
|
||||||
utils.AssertEqual(t, "Nested is empty", c.ReqHeaderParser(q).Error())
|
utils.AssertEqual(t, "failed to decode: Nested is empty", c.ReqHeaderParser(q).Error())
|
||||||
|
|
||||||
c.Request().Header.Del("Nested.Agex")
|
c.Request().Header.Del("Nested.Agex")
|
||||||
c.Request().Header.Del("Name")
|
c.Request().Header.Del("Name")
|
||||||
|
@ -4176,7 +4203,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
||||||
c.Request().Header.Del("Nested.Age")
|
c.Request().Header.Del("Nested.Age")
|
||||||
c.Request().Header.Add("Nested.Agex", "10")
|
c.Request().Header.Add("Nested.Agex", "10")
|
||||||
h2 = new(Header2)
|
h2 = new(Header2)
|
||||||
utils.AssertEqual(t, "Nested.age is empty", c.ReqHeaderParser(h2).Error())
|
utils.AssertEqual(t, "failed to decode: Nested.age is empty", c.ReqHeaderParser(h2).Error())
|
||||||
|
|
||||||
type Node struct {
|
type Node struct {
|
||||||
Value int `reqHeader:"Val,required"`
|
Value int `reqHeader:"Val,required"`
|
||||||
|
@ -4191,7 +4218,7 @@ func Test_Ctx_ReqHeaderParser_Schema(t *testing.T) {
|
||||||
|
|
||||||
c.Request().Header.Del("Val")
|
c.Request().Header.Del("Val")
|
||||||
n = new(Node)
|
n = new(Node)
|
||||||
utils.AssertEqual(t, "Val is empty", c.ReqHeaderParser(n).Error())
|
utils.AssertEqual(t, "failed to decode: Val is empty", c.ReqHeaderParser(n).Error())
|
||||||
|
|
||||||
c.Request().Header.Add("Val", "3")
|
c.Request().Header.Add("Val", "3")
|
||||||
c.Request().Header.Del("Next.Val")
|
c.Request().Header.Del("Next.Val")
|
||||||
|
@ -4628,8 +4655,9 @@ func Test_Ctx_RepeatParserWithSameStruct(t *testing.T) {
|
||||||
|
|
||||||
var gzipJSON bytes.Buffer
|
var gzipJSON bytes.Buffer
|
||||||
w := gzip.NewWriter(&gzipJSON)
|
w := gzip.NewWriter(&gzipJSON)
|
||||||
_, _ = w.Write([]byte(`{"body_param":"body_param"}`))
|
_, _ = w.Write([]byte(`{"body_param":"body_param"}`)) //nolint:errcheck // This will never fail
|
||||||
_ = w.Close()
|
err := w.Close()
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
c.Request().Header.SetContentType(MIMEApplicationJSON)
|
c.Request().Header.SetContentType(MIMEApplicationJSON)
|
||||||
c.Request().Header.Set(HeaderContentEncoding, "gzip")
|
c.Request().Header.Set(HeaderContentEncoding, "gzip")
|
||||||
c.Request().SetBody(gzipJSON.Bytes())
|
c.Request().SetBody(gzipJSON.Bytes())
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
package fiber
|
package fiber
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
jerrors "encoding/json"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/internal/schema"
|
"github.com/gofiber/fiber/v2/internal/schema"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
)
|
)
|
||||||
|
@ -36,42 +35,42 @@ func TestMultiError(t *testing.T) {
|
||||||
|
|
||||||
func TestInvalidUnmarshalError(t *testing.T) {
|
func TestInvalidUnmarshalError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
var e *jerrors.InvalidUnmarshalError
|
var e *json.InvalidUnmarshalError
|
||||||
ok := errors.As(&InvalidUnmarshalError{}, &e)
|
ok := errors.As(&InvalidUnmarshalError{}, &e)
|
||||||
utils.AssertEqual(t, true, ok)
|
utils.AssertEqual(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalerError(t *testing.T) {
|
func TestMarshalerError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
var e *jerrors.MarshalerError
|
var e *json.MarshalerError
|
||||||
ok := errors.As(&MarshalerError{}, &e)
|
ok := errors.As(&MarshalerError{}, &e)
|
||||||
utils.AssertEqual(t, true, ok)
|
utils.AssertEqual(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSyntaxError(t *testing.T) {
|
func TestSyntaxError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
var e *jerrors.SyntaxError
|
var e *json.SyntaxError
|
||||||
ok := errors.As(&SyntaxError{}, &e)
|
ok := errors.As(&SyntaxError{}, &e)
|
||||||
utils.AssertEqual(t, true, ok)
|
utils.AssertEqual(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalTypeError(t *testing.T) {
|
func TestUnmarshalTypeError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
var e *jerrors.UnmarshalTypeError
|
var e *json.UnmarshalTypeError
|
||||||
ok := errors.As(&UnmarshalTypeError{}, &e)
|
ok := errors.As(&UnmarshalTypeError{}, &e)
|
||||||
utils.AssertEqual(t, true, ok)
|
utils.AssertEqual(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnsupportedTypeError(t *testing.T) {
|
func TestUnsupportedTypeError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
var e *jerrors.UnsupportedTypeError
|
var e *json.UnsupportedTypeError
|
||||||
ok := errors.As(&UnsupportedTypeError{}, &e)
|
ok := errors.As(&UnsupportedTypeError{}, &e)
|
||||||
utils.AssertEqual(t, true, ok)
|
utils.AssertEqual(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnsupportedValeError(t *testing.T) {
|
func TestUnsupportedValeError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
var e *jerrors.UnsupportedValueError
|
var e *json.UnsupportedValueError
|
||||||
ok := errors.As(&UnsupportedValueError{}, &e)
|
ok := errors.As(&UnsupportedValueError{}, &e)
|
||||||
utils.AssertEqual(t, true, ok)
|
utils.AssertEqual(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
1
group.go
1
group.go
|
@ -168,7 +168,6 @@ func (grp *Group) Group(prefix string, handlers ...Handler) Router {
|
||||||
}
|
}
|
||||||
|
|
||||||
return newGrp
|
return newGrp
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route is used to define routes with a common prefix inside the common function.
|
// Route is used to define routes with a common prefix inside the common function.
|
||||||
|
|
78
helpers.go
78
helpers.go
|
@ -20,13 +20,13 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* #nosec */
|
// getTLSConfig returns a net listener's tls config
|
||||||
// getTlsConfig returns a net listener's tls config
|
func getTLSConfig(ln net.Listener) *tls.Config {
|
||||||
func getTlsConfig(ln net.Listener) *tls.Config {
|
|
||||||
// Get listener type
|
// Get listener type
|
||||||
pointer := reflect.ValueOf(ln)
|
pointer := reflect.ValueOf(ln)
|
||||||
|
|
||||||
|
@ -37,12 +37,16 @@ func getTlsConfig(ln net.Listener) *tls.Config {
|
||||||
// Get private field from value
|
// Get private field from value
|
||||||
if field := val.FieldByName("config"); field.Type() != nil {
|
if field := val.FieldByName("config"); field.Type() != nil {
|
||||||
// Copy value from pointer field (unsafe)
|
// Copy value from pointer field (unsafe)
|
||||||
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) // #nosec G103
|
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) //nolint:gosec // Probably the only way to extract the *tls.Config from a net.Listener. TODO: Verify there really is no easier way without using unsafe.
|
||||||
if newval.Type() != nil {
|
if newval.Type() != nil {
|
||||||
// Get element from pointer
|
// Get element from pointer
|
||||||
if elem := newval.Elem(); elem.Type() != nil {
|
if elem := newval.Elem(); elem.Type() != nil {
|
||||||
// Cast value to *tls.Config
|
// Cast value to *tls.Config
|
||||||
return elem.Interface().(*tls.Config)
|
c, ok := elem.Interface().(*tls.Config)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("failed to type-assert to *tls.Config"))
|
||||||
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,19 +57,21 @@ func getTlsConfig(ln net.Listener) *tls.Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
// readContent opens a named file and read content from it
|
// readContent opens a named file and read content from it
|
||||||
func readContent(rf io.ReaderFrom, name string) (n int64, err error) {
|
func readContent(rf io.ReaderFrom, name string) (int64, error) {
|
||||||
// Read file
|
// Read file
|
||||||
f, err := os.Open(filepath.Clean(name))
|
f, err := os.Open(filepath.Clean(name))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, fmt.Errorf("failed to open: %w", err)
|
||||||
}
|
}
|
||||||
// #nosec G307
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err = f.Close(); err != nil {
|
if err = f.Close(); err != nil {
|
||||||
log.Printf("Error closing file: %s\n", err)
|
log.Printf("Error closing file: %s\n", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return rf.ReadFrom(f)
|
if n, err := rf.ReadFrom(f); err != nil {
|
||||||
|
return n, fmt.Errorf("failed to read: %w", err)
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// quoteString escape special characters in a given string
|
// quoteString escape special characters in a given string
|
||||||
|
@ -78,7 +84,8 @@ func (app *App) quoteString(raw string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan stack if other methods match the request
|
// Scan stack if other methods match the request
|
||||||
func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
func (app *App) methodExist(ctx *Ctx) bool {
|
||||||
|
var exists bool
|
||||||
methods := app.config.RequestMethods
|
methods := app.config.RequestMethods
|
||||||
for i := 0; i < len(methods); i++ {
|
for i := 0; i < len(methods); i++ {
|
||||||
// Skip original method
|
// Skip original method
|
||||||
|
@ -108,7 +115,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
||||||
// No match, next route
|
// No match, next route
|
||||||
if match {
|
if match {
|
||||||
// We matched
|
// We matched
|
||||||
exist = true
|
exists = true
|
||||||
// Add method to Allow header
|
// Add method to Allow header
|
||||||
ctx.Append(HeaderAllow, methods[i])
|
ctx.Append(HeaderAllow, methods[i])
|
||||||
// Break stack loop
|
// Break stack loop
|
||||||
|
@ -116,7 +123,7 @@ func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// uniqueRouteStack drop all not unique routes from the slice
|
// uniqueRouteStack drop all not unique routes from the slice
|
||||||
|
@ -146,7 +153,7 @@ func defaultString(value string, defaultValue []string) string {
|
||||||
const normalizedHeaderETag = "Etag"
|
const normalizedHeaderETag = "Etag"
|
||||||
|
|
||||||
// Generate and set ETag header to response
|
// Generate and set ETag header to response
|
||||||
func setETag(c *Ctx, weak bool) {
|
func setETag(c *Ctx, weak bool) { //nolint: revive // Accepting a bool param is fine here
|
||||||
// Don't generate ETags for invalid responses
|
// Don't generate ETags for invalid responses
|
||||||
if c.fasthttp.Response.StatusCode() != StatusOK {
|
if c.fasthttp.Response.StatusCode() != StatusOK {
|
||||||
return
|
return
|
||||||
|
@ -160,7 +167,8 @@ func setETag(c *Ctx, weak bool) {
|
||||||
clientEtag := c.Get(HeaderIfNoneMatch)
|
clientEtag := c.Get(HeaderIfNoneMatch)
|
||||||
|
|
||||||
// Generate ETag for response
|
// Generate ETag for response
|
||||||
crc32q := crc32.MakeTable(0xD5828281)
|
const pol = 0xD5828281
|
||||||
|
crc32q := crc32.MakeTable(pol)
|
||||||
etag := fmt.Sprintf("\"%d-%v\"", len(body), crc32.Checksum(body, crc32q))
|
etag := fmt.Sprintf("\"%d-%v\"", len(body), crc32.Checksum(body, crc32q))
|
||||||
|
|
||||||
// Enable weak tag
|
// Enable weak tag
|
||||||
|
@ -173,7 +181,9 @@ func setETag(c *Ctx, weak bool) {
|
||||||
// Check if server's ETag is weak
|
// Check if server's ETag is weak
|
||||||
if clientEtag[2:] == etag || clientEtag[2:] == etag[2:] {
|
if clientEtag[2:] == etag || clientEtag[2:] == etag[2:] {
|
||||||
// W/1 == 1 || W/1 == W/1
|
// W/1 == 1 || W/1 == W/1
|
||||||
_ = c.SendStatus(StatusNotModified)
|
if err := c.SendStatus(StatusNotModified); err != nil {
|
||||||
|
log.Printf("setETag: failed to SendStatus: %v\n", err)
|
||||||
|
}
|
||||||
c.fasthttp.ResetBody()
|
c.fasthttp.ResetBody()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -183,7 +193,9 @@ func setETag(c *Ctx, weak bool) {
|
||||||
}
|
}
|
||||||
if strings.Contains(clientEtag, etag) {
|
if strings.Contains(clientEtag, etag) {
|
||||||
// 1 == 1
|
// 1 == 1
|
||||||
_ = c.SendStatus(StatusNotModified)
|
if err := c.SendStatus(StatusNotModified); err != nil {
|
||||||
|
log.Printf("setETag: failed to SendStatus: %v\n", err)
|
||||||
|
}
|
||||||
c.fasthttp.ResetBody()
|
c.fasthttp.ResetBody()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -239,7 +251,7 @@ func getOffer(header string, offers ...string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func matchEtag(s string, etag string) bool {
|
func matchEtag(s, etag string) bool {
|
||||||
if s == etag || s == "W/"+etag || "W/"+s == etag {
|
if s == etag || s == "W/"+etag || "W/"+s == etag {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -254,12 +266,12 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
|
||||||
// https://github.com/jshttp/fresh/blob/10e0471669dbbfbfd8de65bc6efac2ddd0bfa057/index.js#L110
|
// https://github.com/jshttp/fresh/blob/10e0471669dbbfbfd8de65bc6efac2ddd0bfa057/index.js#L110
|
||||||
for i := range noneMatchBytes {
|
for i := range noneMatchBytes {
|
||||||
switch noneMatchBytes[i] {
|
switch noneMatchBytes[i] {
|
||||||
case 0x20:
|
case 0x20: //nolint:gomnd // This is a space (" ")
|
||||||
if start == end {
|
if start == end {
|
||||||
start = i + 1
|
start = i + 1
|
||||||
end = i + 1
|
end = i + 1
|
||||||
}
|
}
|
||||||
case 0x2c:
|
case 0x2c: //nolint:gomnd // This is a comma (",")
|
||||||
if matchEtag(app.getString(noneMatchBytes[start:end]), etag) {
|
if matchEtag(app.getString(noneMatchBytes[start:end]), etag) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -273,7 +285,7 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
|
||||||
return !matchEtag(app.getString(noneMatchBytes[start:end]), etag)
|
return !matchEtag(app.getString(noneMatchBytes[start:end]), etag)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAddr(raw string) (host, port string) {
|
func parseAddr(raw string) (string, string) { //nolint:revive // Returns (host, port)
|
||||||
if i := strings.LastIndex(raw, ":"); i != -1 {
|
if i := strings.LastIndex(raw, ":"); i != -1 {
|
||||||
return raw[:i], raw[i+1:]
|
return raw[:i], raw[i+1:]
|
||||||
}
|
}
|
||||||
|
@ -313,21 +325,21 @@ type testConn struct {
|
||||||
w bytes.Buffer
|
w bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) }
|
func (c *testConn) Read(b []byte) (int, error) { return c.r.Read(b) } //nolint:wrapcheck // This must not be wrapped
|
||||||
func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) }
|
func (c *testConn) Write(b []byte) (int, error) { return c.w.Write(b) } //nolint:wrapcheck // This must not be wrapped
|
||||||
func (c *testConn) Close() error { return nil }
|
func (*testConn) Close() error { return nil }
|
||||||
|
|
||||||
func (c *testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
func (*testConn) LocalAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
||||||
func (c *testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
func (*testConn) RemoteAddr() net.Addr { return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero} }
|
||||||
func (c *testConn) SetDeadline(_ time.Time) error { return nil }
|
func (*testConn) SetDeadline(_ time.Time) error { return nil }
|
||||||
func (c *testConn) SetReadDeadline(_ time.Time) error { return nil }
|
func (*testConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||||
func (c *testConn) SetWriteDeadline(_ time.Time) error { return nil }
|
func (*testConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||||
|
|
||||||
var getStringImmutable = func(b []byte) string {
|
func getStringImmutable(b []byte) string {
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
var getBytesImmutable = func(s string) (b []byte) {
|
func getBytesImmutable(s string) []byte {
|
||||||
return []byte(s)
|
return []byte(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -335,6 +347,7 @@ var getBytesImmutable = func(s string) (b []byte) {
|
||||||
func (app *App) methodInt(s string) int {
|
func (app *App) methodInt(s string) int {
|
||||||
// For better performance
|
// For better performance
|
||||||
if len(app.configured.RequestMethods) == 0 {
|
if len(app.configured.RequestMethods) == 0 {
|
||||||
|
//nolint:gomnd // TODO: Use iota instead
|
||||||
switch s {
|
switch s {
|
||||||
case MethodGet:
|
case MethodGet:
|
||||||
return 0
|
return 0
|
||||||
|
@ -391,8 +404,7 @@ func IsMethodIdempotent(m string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch m {
|
switch m {
|
||||||
case MethodPut,
|
case MethodPut, MethodDelete:
|
||||||
MethodDelete:
|
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
@ -714,7 +726,7 @@ const (
|
||||||
ConstraintBool = "bool"
|
ConstraintBool = "bool"
|
||||||
ConstraintFloat = "float"
|
ConstraintFloat = "float"
|
||||||
ConstraintAlpha = "alpha"
|
ConstraintAlpha = "alpha"
|
||||||
ConstraintGuid = "guid"
|
ConstraintGuid = "guid" //nolint:revive,stylecheck // TODO: Rename to "ConstraintGUID" in v3
|
||||||
ConstraintMinLen = "minLen"
|
ConstraintMinLen = "minLen"
|
||||||
ConstraintMaxLen = "maxLen"
|
ConstraintMaxLen = "maxLen"
|
||||||
ConstraintLen = "len"
|
ConstraintLen = "len"
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
30
hooks.go
30
hooks.go
|
@ -1,14 +1,20 @@
|
||||||
package fiber
|
package fiber
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
// OnRouteHandler Handlers define a function to create hooks for Fiber.
|
// OnRouteHandler Handlers define a function to create hooks for Fiber.
|
||||||
type OnRouteHandler = func(Route) error
|
type (
|
||||||
type OnNameHandler = OnRouteHandler
|
OnRouteHandler = func(Route) error
|
||||||
type OnGroupHandler = func(Group) error
|
OnNameHandler = OnRouteHandler
|
||||||
type OnGroupNameHandler = OnGroupHandler
|
OnGroupHandler = func(Group) error
|
||||||
type OnListenHandler = func() error
|
OnGroupNameHandler = OnGroupHandler
|
||||||
type OnShutdownHandler = OnListenHandler
|
OnListenHandler = func() error
|
||||||
type OnForkHandler = func(int) error
|
OnShutdownHandler = OnListenHandler
|
||||||
type OnMountHandler = func(*App) error
|
OnForkHandler = func(int) error
|
||||||
|
OnMountHandler = func(*App) error
|
||||||
|
)
|
||||||
|
|
||||||
// Hooks is a struct to use it with App.
|
// Hooks is a struct to use it with App.
|
||||||
type Hooks struct {
|
type Hooks struct {
|
||||||
|
@ -180,13 +186,17 @@ func (h *Hooks) executeOnListenHooks() error {
|
||||||
|
|
||||||
func (h *Hooks) executeOnShutdownHooks() {
|
func (h *Hooks) executeOnShutdownHooks() {
|
||||||
for _, v := range h.onShutdown {
|
for _, v := range h.onShutdown {
|
||||||
_ = v()
|
if err := v(); err != nil {
|
||||||
|
log.Printf("failed to call shutdown hook: %v\n", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hooks) executeOnForkHooks(pid int) {
|
func (h *Hooks) executeOnForkHooks(pid int) {
|
||||||
for _, v := range h.onFork {
|
for _, v := range h.onFork {
|
||||||
_ = v(pid)
|
if err := v(pid); err != nil {
|
||||||
|
log.Printf("failed to call fork hook: %v\n", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,10 +7,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testSimpleHandler = func(c *Ctx) error {
|
func testSimpleHandler(c *Ctx) error {
|
||||||
return c.SendString("simple")
|
return c.SendString("simple")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ var defaultPool = sync.Pool{
|
||||||
|
|
||||||
// AcquireDict acquire new dict.
|
// AcquireDict acquire new dict.
|
||||||
func AcquireDict() *Dict {
|
func AcquireDict() *Dict {
|
||||||
return defaultPool.Get().(*Dict) // nolint:forcetypeassert
|
return defaultPool.Get().(*Dict)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseDict release dict.
|
// ReleaseDict release dict.
|
||||||
|
|
|
@ -366,7 +366,7 @@ func HostDev(combineWith ...string) string {
|
||||||
// getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running
|
// getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running
|
||||||
// sysctl commands (see DoSysctrl).
|
// sysctl commands (see DoSysctrl).
|
||||||
func getSysctrlEnv(env []string) []string {
|
func getSysctrlEnv(env []string) []string {
|
||||||
foundLC := false
|
var foundLC bool
|
||||||
for i, line := range env {
|
for i, line := range env {
|
||||||
if strings.HasPrefix(line, "LC_ALL") {
|
if strings.HasPrefix(line, "LC_ALL") {
|
||||||
env[i] = "LC_ALL=C"
|
env[i] = "LC_ALL=C"
|
||||||
|
|
|
@ -6,8 +6,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2/internal/gopsutil/common"
|
"github.com/gofiber/fiber/v2/internal/gopsutil/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
//lint:ignore U1000 we need this elsewhere
|
var invoke common.Invoker = common.Invoke{} //nolint:unused // We use this only for some OS'es
|
||||||
var invoke common.Invoker = common.Invoke{} //nolint:all
|
|
||||||
|
|
||||||
// Memory usage statistics. Total, Available and Used contain numbers of bytes
|
// Memory usage statistics. Total, Available and Used contain numbers of bytes
|
||||||
// for human consumption.
|
// for human consumption.
|
||||||
|
|
|
@ -86,7 +86,6 @@ func SwapMemory() (*SwapMemoryStat, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Constants from vm/vm_param.h
|
// Constants from vm/vm_param.h
|
||||||
// nolint: golint
|
|
||||||
const (
|
const (
|
||||||
XSWDEV_VERSION11 = 1
|
XSWDEV_VERSION11 = 1
|
||||||
XSWDEV_VERSION = 2
|
XSWDEV_VERSION = 2
|
||||||
|
|
|
@ -57,10 +57,12 @@ func fillFromMeminfoWithContext(ctx context.Context) (*VirtualMemoryStat, *Virtu
|
||||||
lines, _ := common.ReadLines(filename)
|
lines, _ := common.ReadLines(filename)
|
||||||
|
|
||||||
// flag if MemAvailable is in /proc/meminfo (kernel 3.14+)
|
// flag if MemAvailable is in /proc/meminfo (kernel 3.14+)
|
||||||
memavail := false
|
var (
|
||||||
activeFile := false // "Active(file)" not available: 2.6.28 / Dec 2008
|
memavail bool
|
||||||
inactiveFile := false // "Inactive(file)" not available: 2.6.28 / Dec 2008
|
activeFile bool // "Active(file)" not available: 2.6.28 / Dec 2008
|
||||||
sReclaimable := false // "SReclaimable:" not available: 2.6.19 / Nov 2006
|
inactiveFile bool // "Inactive(file)" not available: 2.6.28 / Dec 2008
|
||||||
|
sReclaimable bool // "SReclaimable:" not available: 2.6.19 / Nov 2006
|
||||||
|
)
|
||||||
|
|
||||||
ret := &VirtualMemoryStat{}
|
ret := &VirtualMemoryStat{}
|
||||||
retEx := &VirtualMemoryExStat{}
|
retEx := &VirtualMemoryExStat{}
|
||||||
|
|
|
@ -168,7 +168,7 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
comma := false
|
var comma bool
|
||||||
for i := uint32(0); i < sz; i++ {
|
for i := uint32(0); i < sz; i++ {
|
||||||
if comma {
|
if comma {
|
||||||
err = dst.WriteByte(',')
|
err = dst.WriteByte(',')
|
||||||
|
|
|
@ -163,7 +163,7 @@ func (c *cache) createField(field reflect.StructField, parentAlias string) *fiel
|
||||||
}
|
}
|
||||||
// Check if the type is supported and don't cache it if not.
|
// Check if the type is supported and don't cache it if not.
|
||||||
// First let's get the basic type.
|
// First let's get the basic type.
|
||||||
isSlice, isStruct := false, false
|
var isSlice, isStruct bool
|
||||||
ft := field.Type
|
ft := field.Type
|
||||||
m := isTextUnmarshaler(reflect.Zero(ft))
|
m := isTextUnmarshaler(reflect.Zero(ft))
|
||||||
if ft.Kind() == reflect.Ptr {
|
if ft.Kind() == reflect.Ptr {
|
||||||
|
|
|
@ -149,13 +149,13 @@ func Benchmark_Storage_Memory(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
d.Set(key, value, ttl)
|
_ = d.Set(key, value, ttl)
|
||||||
}
|
}
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
_, _ = d.Get(key)
|
_, _ = d.Get(key)
|
||||||
}
|
}
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
d.Delete(key)
|
_ = d.Delete(key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -156,7 +156,6 @@ func (e *Engine) Load() error {
|
||||||
name = strings.TrimSuffix(name, e.extension)
|
name = strings.TrimSuffix(name, e.extension)
|
||||||
// name = strings.Replace(name, e.extension, "", -1)
|
// name = strings.Replace(name, e.extension, "", -1)
|
||||||
// Read the file
|
// Read the file
|
||||||
// #gosec G304
|
|
||||||
buf, err := utils.ReadFile(path, e.fileSystem)
|
buf, err := utils.ReadFile(path, e.fileSystem)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -21,7 +21,6 @@ func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
|
||||||
return walk(fs, root, info, walkFn)
|
return walk(fs, root, info, walkFn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// #nosec G304
|
|
||||||
// ReadFile returns the raw content of a file
|
// ReadFile returns the raw content of a file
|
||||||
func ReadFile(path string, fs http.FileSystem) ([]byte, error) {
|
func ReadFile(path string, fs http.FileSystem) ([]byte, error) {
|
||||||
if fs != nil {
|
if fs != nil {
|
||||||
|
|
62
listen.go
62
listen.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -31,7 +32,7 @@ func (app *App) Listener(ln net.Listener) error {
|
||||||
|
|
||||||
// Print startup message
|
// Print startup message
|
||||||
if !app.config.DisableStartupMessage {
|
if !app.config.DisableStartupMessage {
|
||||||
app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "")
|
app.startupMessage(ln.Addr().String(), getTLSConfig(ln) != nil, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print routes
|
// Print routes
|
||||||
|
@ -41,7 +42,7 @@ func (app *App) Listener(ln net.Listener) error {
|
||||||
|
|
||||||
// Prefork is not supported for custom listeners
|
// Prefork is not supported for custom listeners
|
||||||
if app.config.Prefork {
|
if app.config.Prefork {
|
||||||
fmt.Println("[Warning] Prefork isn't supported for custom listeners.")
|
log.Printf("[Warning] Prefork isn't supported for custom listeners.\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start listening
|
// Start listening
|
||||||
|
@ -61,7 +62,7 @@ func (app *App) Listen(addr string) error {
|
||||||
// Setup listener
|
// Setup listener
|
||||||
ln, err := net.Listen(app.config.Network, addr)
|
ln, err := net.Listen(app.config.Network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to listen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare the server for the start
|
// prepare the server for the start
|
||||||
|
@ -94,7 +95,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
|
||||||
// Set TLS config with handler
|
// Set TLS config with handler
|
||||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
|
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsHandler := &TLSHandler{}
|
tlsHandler := &TLSHandler{}
|
||||||
|
@ -115,7 +116,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
|
||||||
ln, err := net.Listen(app.config.Network, addr)
|
ln, err := net.Listen(app.config.Network, addr)
|
||||||
ln = tls.NewListener(ln, config)
|
ln = tls.NewListener(ln, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to listen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare the server for the start
|
// prepare the server for the start
|
||||||
|
@ -150,12 +151,12 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
|
||||||
|
|
||||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
|
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile))
|
clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read file: %w", err)
|
||||||
}
|
}
|
||||||
clientCertPool := x509.NewCertPool()
|
clientCertPool := x509.NewCertPool()
|
||||||
clientCertPool.AppendCertsFromPEM(clientCACert)
|
clientCertPool.AppendCertsFromPEM(clientCACert)
|
||||||
|
@ -179,7 +180,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
|
||||||
// Setup listener
|
// Setup listener
|
||||||
ln, err := tls.Listen(app.config.Network, addr, config)
|
ln, err := tls.Listen(app.config.Network, addr, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to listen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare the server for the start
|
// prepare the server for the start
|
||||||
|
@ -203,7 +204,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// startupMessage prepares the startup message with the handler number, port, address and other information
|
// startupMessage prepares the startup message with the handler number, port, address and other information
|
||||||
func (app *App) startupMessage(addr string, tls bool, pids string) {
|
func (app *App) startupMessage(addr string, tls bool, pids string) { //nolint: revive // Accepting a bool param is fine here
|
||||||
// ignore child processes
|
// ignore child processes
|
||||||
if IsChild() {
|
if IsChild() {
|
||||||
return
|
return
|
||||||
|
@ -227,7 +228,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
center := func(s string, width int) string {
|
center := func(s string, width int) string {
|
||||||
pad := strconv.Itoa((width - len(s)) / 2)
|
const padDiv = 2
|
||||||
|
pad := strconv.Itoa((width - len(s)) / padDiv)
|
||||||
str := fmt.Sprintf("%"+pad+"s", " ")
|
str := fmt.Sprintf("%"+pad+"s", " ")
|
||||||
str += s
|
str += s
|
||||||
str += fmt.Sprintf("%"+pad+"s", " ")
|
str += fmt.Sprintf("%"+pad+"s", " ")
|
||||||
|
@ -238,7 +240,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
centerValue := func(s string, width int) string {
|
centerValue := func(s string, width int) string {
|
||||||
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / 2)
|
const padDiv = 2
|
||||||
|
pad := strconv.Itoa((width - runewidth.StringWidth(s)) / padDiv)
|
||||||
str := fmt.Sprintf("%"+pad+"s", " ")
|
str := fmt.Sprintf("%"+pad+"s", " ")
|
||||||
str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black)
|
str += fmt.Sprintf("%s%s%s", colors.Cyan, s, colors.Black)
|
||||||
str += fmt.Sprintf("%"+pad+"s", " ")
|
str += fmt.Sprintf("%"+pad+"s", " ")
|
||||||
|
@ -249,13 +252,13 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
pad := func(s string, width int) (str string) {
|
pad := func(s string, width int) string {
|
||||||
toAdd := width - len(s)
|
toAdd := width - len(s)
|
||||||
str += s
|
str := s
|
||||||
for i := 0; i < toAdd; i++ {
|
for i := 0; i < toAdd; i++ {
|
||||||
str += " "
|
str += " "
|
||||||
}
|
}
|
||||||
return
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
host, port := parseAddr(addr)
|
host, port := parseAddr(addr)
|
||||||
|
@ -267,9 +270,9 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
scheme := "http"
|
scheme := schemeHTTP
|
||||||
if tls {
|
if tls {
|
||||||
scheme = "https"
|
scheme = schemeHTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
isPrefork := "Disabled"
|
isPrefork := "Disabled"
|
||||||
|
@ -282,19 +285,18 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
procs = "1"
|
procs = "1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const lineLen = 49
|
||||||
mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n"
|
mainLogo := colors.Black + " ┌───────────────────────────────────────────────────┐\n"
|
||||||
if app.config.AppName != "" {
|
if app.config.AppName != "" {
|
||||||
mainLogo += " │ " + centerValue(app.config.AppName, 49) + " │\n"
|
mainLogo += " │ " + centerValue(app.config.AppName, lineLen) + " │\n"
|
||||||
}
|
}
|
||||||
mainLogo += " │ " + centerValue("Fiber v"+Version, 49) + " │\n"
|
mainLogo += " │ " + centerValue("Fiber v"+Version, lineLen) + " │\n"
|
||||||
|
|
||||||
if host == "0.0.0.0" {
|
if host == "0.0.0.0" {
|
||||||
mainLogo +=
|
mainLogo += " │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), lineLen) + " │\n" +
|
||||||
" │ " + center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), 49) + " │\n" +
|
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), lineLen) + " │\n"
|
||||||
" │ " + center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), 49) + " │\n"
|
|
||||||
} else {
|
} else {
|
||||||
mainLogo +=
|
mainLogo += " │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), lineLen) + " │\n"
|
||||||
" │ " + center(fmt.Sprintf("%s://%s:%s", scheme, host, port), 49) + " │\n"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mainLogo += fmt.Sprintf(
|
mainLogo += fmt.Sprintf(
|
||||||
|
@ -303,8 +305,8 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
" │ Prefork .%s PID ....%s │\n"+
|
" │ Prefork .%s PID ....%s │\n"+
|
||||||
" └───────────────────────────────────────────────────┘"+
|
" └───────────────────────────────────────────────────┘"+
|
||||||
colors.Reset,
|
colors.Reset,
|
||||||
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12),
|
value(strconv.Itoa(int(app.handlersCount)), 14), value(procs, 12), //nolint:gomnd // Using random padding lengths is fine here
|
||||||
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
|
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14), //nolint:gomnd // Using random padding lengths is fine here
|
||||||
)
|
)
|
||||||
|
|
||||||
var childPidsLogo string
|
var childPidsLogo string
|
||||||
|
@ -329,19 +331,21 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
thisLine := "Child PIDs ... "
|
thisLine := "Child PIDs ... "
|
||||||
var itemsOnThisLine []string
|
var itemsOnThisLine []string
|
||||||
|
|
||||||
|
const maxLineLen = 49
|
||||||
|
|
||||||
addLine := func() {
|
addLine := func() {
|
||||||
lines = append(lines,
|
lines = append(lines,
|
||||||
fmt.Sprintf(
|
fmt.Sprintf(
|
||||||
newLine,
|
newLine,
|
||||||
colors.Black,
|
colors.Black,
|
||||||
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), 49-len(thisLine)),
|
thisLine+colors.Cyan+pad(strings.Join(itemsOnThisLine, ", "), maxLineLen-len(thisLine)),
|
||||||
colors.Black,
|
colors.Black,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pid := range pidSlice {
|
for _, pid := range pidSlice {
|
||||||
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > 49 {
|
if len(thisLine+strings.Join(append(itemsOnThisLine, pid), ", ")) > maxLineLen {
|
||||||
addLine()
|
addLine()
|
||||||
thisLine = ""
|
thisLine = ""
|
||||||
itemsOnThisLine = []string{pid}
|
itemsOnThisLine = []string{pid}
|
||||||
|
@ -415,7 +419,7 @@ func (app *App) printRoutesMessage() {
|
||||||
var routes []RouteMessage
|
var routes []RouteMessage
|
||||||
for _, routeStack := range app.stack {
|
for _, routeStack := range app.stack {
|
||||||
for _, route := range routeStack {
|
for _, route := range routeStack {
|
||||||
var newRoute = RouteMessage{}
|
var newRoute RouteMessage
|
||||||
newRoute.name = route.Name
|
newRoute.name = route.Name
|
||||||
newRoute.method = route.Method
|
newRoute.method = route.Method
|
||||||
newRoute.path = route.Path
|
newRoute.path = route.Path
|
||||||
|
@ -443,5 +447,5 @@ func (app *App) printRoutesMessage() {
|
||||||
_, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers)
|
_, _ = fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = w.Flush()
|
_ = w.Flush() //nolint:errcheck // It is fine to ignore the error here
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ package fiber
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
@ -17,6 +16,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp/fasthttputil"
|
"github.com/valyala/fasthttp/fasthttputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -125,8 +125,10 @@ func Test_App_Listener_TLS_Listener(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
//nolint:gosec // We're in a test so using old ciphers is fine
|
||||||
config := &tls.Config{Certificates: []tls.Certificate{cer}}
|
config := &tls.Config{Certificates: []tls.Certificate{cer}}
|
||||||
|
|
||||||
|
//nolint:gosec // We're in a test so listening on all interfaces is fine
|
||||||
ln, err := tls.Listen(NetworkTCP4, ":0", config)
|
ln, err := tls.Listen(NetworkTCP4, ":0", config)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -182,7 +184,6 @@ func Test_App_Master_Process_Show_Startup_Message(t *testing.T) {
|
||||||
New(Config{Prefork: true}).
|
New(Config{Prefork: true}).
|
||||||
startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
|
startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
|
||||||
})
|
})
|
||||||
fmt.Println(startupMessage)
|
|
||||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "https://127.0.0.1:3000"))
|
utils.AssertEqual(t, true, strings.Contains(startupMessage, "https://127.0.0.1:3000"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "(bound on host 0.0.0.0 and port 3000)"))
|
utils.AssertEqual(t, true, strings.Contains(startupMessage, "(bound on host 0.0.0.0 and port 3000)"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "Child PIDs"))
|
utils.AssertEqual(t, true, strings.Contains(startupMessage, "Child PIDs"))
|
||||||
|
@ -196,7 +197,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) {
|
||||||
startupMessage := captureOutput(func() {
|
startupMessage := captureOutput(func() {
|
||||||
app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
|
app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10))
|
||||||
})
|
})
|
||||||
fmt.Println(startupMessage)
|
|
||||||
utils.AssertEqual(t, "Test App v1.0.1", app.Config().AppName)
|
utils.AssertEqual(t, "Test App v1.0.1", app.Config().AppName)
|
||||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, app.Config().AppName))
|
utils.AssertEqual(t, true, strings.Contains(startupMessage, app.Config().AppName))
|
||||||
}
|
}
|
||||||
|
@ -208,7 +208,6 @@ func Test_App_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.
|
||||||
startupMessage := captureOutput(func() {
|
startupMessage := captureOutput(func() {
|
||||||
app.startupMessage(":3000", false, "")
|
app.startupMessage(":3000", false, "")
|
||||||
})
|
})
|
||||||
fmt.Println(startupMessage)
|
|
||||||
utils.AssertEqual(t, true, strings.Contains(startupMessage, "│ Serveur de vérification des données │"))
|
utils.AssertEqual(t, true, strings.Contains(startupMessage, "│ Serveur de vérification des données │"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,8 +218,7 @@ func Test_App_print_Route(t *testing.T) {
|
||||||
printRoutesMessage := captureOutput(func() {
|
printRoutesMessage := captureOutput(func() {
|
||||||
app.printRoutesMessage()
|
app.printRoutesMessage()
|
||||||
})
|
})
|
||||||
fmt.Println(printRoutesMessage)
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET"))
|
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "routeName"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "routeName"))
|
||||||
|
@ -240,11 +238,11 @@ func Test_App_print_Route_with_group(t *testing.T) {
|
||||||
app.printRoutesMessage()
|
app.printRoutesMessage()
|
||||||
})
|
})
|
||||||
|
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "GET"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodGet))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "emptyHandler"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "POST"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, MethodPost))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "PUT"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "PUT"))
|
||||||
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber/*"))
|
utils.AssertEqual(t, true, strings.Contains(printRoutesMessage, "/v1/test/fiber/*"))
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
package basicauth
|
package basicauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
b64 "encoding/base64"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ func Test_BasicAuth_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,7 @@ func Test_Middleware_BasicAuth(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
//nolint:forcetypeassert,errcheck // TODO: Do not force-type assert
|
||||||
app.Get("/testauth", func(c *fiber.Ctx) error {
|
app.Get("/testauth", func(c *fiber.Ctx) error {
|
||||||
username := c.Locals("username").(string)
|
username := c.Locals("username").(string)
|
||||||
password := c.Locals("password").(string)
|
password := c.Locals("password").(string)
|
||||||
|
@ -74,9 +75,9 @@ func Test_Middleware_BasicAuth(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
// Base64 encode credentials for http auth header
|
// Base64 encode credentials for http auth header
|
||||||
creds := b64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
|
creds := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/testauth", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil)
|
||||||
req.Header.Add("Authorization", "Basic "+creds)
|
req.Header.Add("Authorization", "Basic "+creds)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -108,7 +109,7 @@ func Benchmark_Middleware_BasicAuth(b *testing.B) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/")
|
fctx.Request.SetRequestURI("/")
|
||||||
fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe
|
fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe
|
||||||
|
|
||||||
|
|
|
@ -53,6 +53,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
Users: map[string]string{},
|
Users: map[string]string{},
|
||||||
|
|
|
@ -34,6 +34,7 @@ const (
|
||||||
noStore = "no-store"
|
noStore = "no-store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var ignoreHeaders = map[string]interface{}{
|
var ignoreHeaders = map[string]interface{}{
|
||||||
"Connection": nil,
|
"Connection": nil,
|
||||||
"Keep-Alive": nil,
|
"Keep-Alive": nil,
|
||||||
|
@ -43,8 +44,8 @@ var ignoreHeaders = map[string]interface{}{
|
||||||
"Trailers": nil,
|
"Trailers": nil,
|
||||||
"Transfer-Encoding": nil,
|
"Transfer-Encoding": nil,
|
||||||
"Upgrade": nil,
|
"Upgrade": nil,
|
||||||
"Content-Type": nil, // already stored explicitely by the cache manager
|
"Content-Type": nil, // already stored explicitly by the cache manager
|
||||||
"Content-Encoding": nil, // already stored explicitely by the cache manager
|
"Content-Encoding": nil, // already stored explicitly by the cache manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new middleware handler
|
// New creates a new middleware handler
|
||||||
|
@ -69,7 +70,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
// Create indexed heap for tracking expirations ( see heap.go )
|
// Create indexed heap for tracking expirations ( see heap.go )
|
||||||
heap := &indexedHeap{}
|
heap := &indexedHeap{}
|
||||||
// count stored bytes (sizes of response bodies)
|
// count stored bytes (sizes of response bodies)
|
||||||
var storedBytes uint = 0
|
var storedBytes uint
|
||||||
|
|
||||||
// Update timestamp in the configured interval
|
// Update timestamp in the configured interval
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -81,10 +82,10 @@ func New(config ...Config) fiber.Handler {
|
||||||
|
|
||||||
// Delete key from both manager and storage
|
// Delete key from both manager and storage
|
||||||
deleteKey := func(dkey string) {
|
deleteKey := func(dkey string) {
|
||||||
manager.delete(dkey)
|
manager.del(dkey)
|
||||||
// External storage saves body data with different key
|
// External storage saves body data with different key
|
||||||
if cfg.Storage != nil {
|
if cfg.Storage != nil {
|
||||||
manager.delete(dkey + "_body")
|
manager.del(dkey + "_body")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,7 +206,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
if cfg.StoreResponseHeaders {
|
if cfg.StoreResponseHeaders {
|
||||||
e.headers = make(map[string][]byte)
|
e.headers = make(map[string][]byte)
|
||||||
c.Response().Header.VisitAll(
|
c.Response().Header.VisitAll(
|
||||||
func(key []byte, value []byte) {
|
func(key, value []byte) {
|
||||||
// create real copy
|
// create real copy
|
||||||
keyS := string(key)
|
keyS := string(key)
|
||||||
if _, ok := ignoreHeaders[keyS]; !ok {
|
if _, ok := ignoreHeaders[keyS]; !ok {
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -18,6 +17,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||||
"github.com/gofiber/fiber/v2/middleware/etag"
|
"github.com/gofiber/fiber/v2/middleware/etag"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,10 +35,10 @@ func Test_Cache_CacheControl(t *testing.T) {
|
||||||
return c.SendString("Hello, World!")
|
return c.SendString("Hello, World!")
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
|
utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ func Test_Cache_Expired(t *testing.T) {
|
||||||
return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano()))
|
return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -61,7 +61,7 @@ func Test_Cache_Expired(t *testing.T) {
|
||||||
// Sleep until the cache is expired
|
// Sleep until the cache is expired
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
bodyCached, err := io.ReadAll(respCached.Body)
|
bodyCached, err := io.ReadAll(respCached.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -71,7 +71,7 @@ func Test_Cache_Expired(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next response should be also cached
|
// Next response should be also cached
|
||||||
respCachedNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
|
bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -92,11 +92,11 @@ func Test_Cache(t *testing.T) {
|
||||||
return c.SendString(now)
|
return c.SendString(now)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
cachedReq := httptest.NewRequest("GET", "/", nil)
|
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
cachedResp, err := app.Test(cachedReq)
|
cachedResp, err := app.Test(cachedReq)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -120,31 +120,31 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Request id = 1
|
// Request id = 1
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
defer resp.Body.Close()
|
utils.AssertEqual(t, nil, err)
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||||
utils.AssertEqual(t, []byte("1"), body)
|
utils.AssertEqual(t, []byte("1"), body)
|
||||||
// Response cached, entry id = 1
|
// Response cached, entry id = 1
|
||||||
|
|
||||||
// Request id = 2 without Cache-Control: no-cache
|
// Request id = 2 without Cache-Control: no-cache
|
||||||
cachedReq := httptest.NewRequest("GET", "/?id=2", nil)
|
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||||
cachedResp, err := app.Test(cachedReq)
|
cachedResp, err := app.Test(cachedReq)
|
||||||
defer cachedResp.Body.Close()
|
utils.AssertEqual(t, nil, err)
|
||||||
cachedBody, _ := io.ReadAll(cachedResp.Body)
|
cachedBody, err := io.ReadAll(cachedResp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
||||||
utils.AssertEqual(t, []byte("1"), cachedBody)
|
utils.AssertEqual(t, []byte("1"), cachedBody)
|
||||||
// Response not cached, returns cached response, entry id = 1
|
// Response not cached, returns cached response, entry id = 1
|
||||||
|
|
||||||
// Request id = 2 with Cache-Control: no-cache
|
// Request id = 2 with Cache-Control: no-cache
|
||||||
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil)
|
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||||
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||||
noCacheResp, err := app.Test(noCacheReq)
|
noCacheResp, err := app.Test(noCacheReq)
|
||||||
defer noCacheResp.Body.Close()
|
utils.AssertEqual(t, nil, err)
|
||||||
noCacheBody, _ := io.ReadAll(noCacheResp.Body)
|
noCacheBody, err := io.ReadAll(noCacheResp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
|
||||||
utils.AssertEqual(t, []byte("2"), noCacheBody)
|
utils.AssertEqual(t, []byte("2"), noCacheBody)
|
||||||
|
@ -152,21 +152,21 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
|
||||||
|
|
||||||
/* Check Test_Cache_WithETagAndNoCacheRequestDirective */
|
/* Check Test_Cache_WithETagAndNoCacheRequestDirective */
|
||||||
// Request id = 2 with Cache-Control: no-cache again
|
// Request id = 2 with Cache-Control: no-cache again
|
||||||
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil)
|
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||||
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||||
noCacheResp1, err := app.Test(noCacheReq1)
|
noCacheResp1, err := app.Test(noCacheReq1)
|
||||||
defer noCacheResp1.Body.Close()
|
utils.AssertEqual(t, nil, err)
|
||||||
noCacheBody1, _ := io.ReadAll(noCacheResp1.Body)
|
noCacheBody1, err := io.ReadAll(noCacheResp1.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
|
||||||
utils.AssertEqual(t, []byte("2"), noCacheBody1)
|
utils.AssertEqual(t, []byte("2"), noCacheBody1)
|
||||||
// Response cached, returns updated response, entry = 2
|
// Response cached, returns updated response, entry = 2
|
||||||
|
|
||||||
// Request id = 1 without Cache-Control: no-cache
|
// Request id = 1 without Cache-Control: no-cache
|
||||||
cachedReq1 := httptest.NewRequest("GET", "/", nil)
|
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
cachedResp1, err := app.Test(cachedReq1)
|
cachedResp1, err := app.Test(cachedReq1)
|
||||||
defer cachedResp1.Body.Close()
|
utils.AssertEqual(t, nil, err)
|
||||||
cachedBody1, _ := io.ReadAll(cachedResp1.Body)
|
cachedBody1, err := io.ReadAll(cachedResp1.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
||||||
utils.AssertEqual(t, []byte("2"), cachedBody1)
|
utils.AssertEqual(t, []byte("2"), cachedBody1)
|
||||||
|
@ -188,7 +188,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Request id = 1
|
// Request id = 1
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||||
|
@ -199,7 +199,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
||||||
etagToken := resp.Header.Get("Etag")
|
etagToken := resp.Header.Get("Etag")
|
||||||
|
|
||||||
// Request id = 2 with ETag but without Cache-Control: no-cache
|
// Request id = 2 with ETag but without Cache-Control: no-cache
|
||||||
cachedReq := httptest.NewRequest("GET", "/?id=2", nil)
|
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||||
cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||||
cachedResp, err := app.Test(cachedReq)
|
cachedResp, err := app.Test(cachedReq)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -208,7 +208,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
||||||
// Response not cached, returns cached response, entry id = 1, status not modified
|
// Response not cached, returns cached response, entry id = 1, status not modified
|
||||||
|
|
||||||
// Request id = 2 with ETag and Cache-Control: no-cache
|
// Request id = 2 with ETag and Cache-Control: no-cache
|
||||||
noCacheReq := httptest.NewRequest("GET", "/?id=2", nil)
|
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||||
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||||
noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||||
noCacheResp, err := app.Test(noCacheReq)
|
noCacheResp, err := app.Test(noCacheReq)
|
||||||
|
@ -221,7 +221,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
||||||
etagToken = noCacheResp.Header.Get("Etag")
|
etagToken = noCacheResp.Header.Get("Etag")
|
||||||
|
|
||||||
// Request id = 2 with ETag and Cache-Control: no-cache again
|
// Request id = 2 with ETag and Cache-Control: no-cache again
|
||||||
noCacheReq1 := httptest.NewRequest("GET", "/?id=2", nil)
|
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||||
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||||
noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||||
noCacheResp1, err := app.Test(noCacheReq1)
|
noCacheResp1, err := app.Test(noCacheReq1)
|
||||||
|
@ -231,7 +231,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
||||||
// Response cached, returns updated response, entry id = 2, status not modified
|
// Response cached, returns updated response, entry id = 2, status not modified
|
||||||
|
|
||||||
// Request id = 1 without ETag and Cache-Control: no-cache
|
// Request id = 1 without ETag and Cache-Control: no-cache
|
||||||
cachedReq1 := httptest.NewRequest("GET", "/", nil)
|
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
cachedResp1, err := app.Test(cachedReq1)
|
cachedResp1, err := app.Test(cachedReq1)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
||||||
|
@ -251,11 +251,11 @@ func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Request id = 2
|
// Request id = 2
|
||||||
noStoreReq := httptest.NewRequest("GET", "/?id=2", nil)
|
noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||||
noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
|
noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
|
||||||
noStoreResp, err := app.Test(noStoreReq)
|
noStoreResp, err := app.Test(noStoreReq)
|
||||||
defer noStoreResp.Body.Close()
|
utils.AssertEqual(t, nil, err)
|
||||||
noStoreBody, _ := io.ReadAll(noStoreResp.Body)
|
noStoreBody, err := io.ReadAll(noStoreResp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, []byte("2"), noStoreBody)
|
utils.AssertEqual(t, []byte("2"), noStoreBody)
|
||||||
// Response not cached, returns updated response
|
// Response not cached, returns updated response
|
||||||
|
@ -278,11 +278,11 @@ func Test_Cache_WithSeveralRequests(t *testing.T) {
|
||||||
for runs := 0; runs < 10; runs++ {
|
for runs := 0; runs < 10; runs++ {
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
func(id int) {
|
func(id int) {
|
||||||
rsp, err := app.Test(httptest.NewRequest(http.MethodGet, fmt.Sprintf("/%d", id), nil))
|
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
defer func(Body io.ReadCloser) {
|
defer func(body io.ReadCloser) {
|
||||||
err := Body.Close()
|
err := body.Close()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
}(rsp.Body)
|
}(rsp.Body)
|
||||||
|
|
||||||
|
@ -311,11 +311,11 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
|
||||||
return c.SendString(now)
|
return c.SendString(now)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
cachedReq := httptest.NewRequest("GET", "/", nil)
|
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
cachedResp, err := app.Test(cachedReq)
|
cachedResp, err := app.Test(cachedReq)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -342,25 +342,25 @@ func Test_Cache_Get(t *testing.T) {
|
||||||
return c.SendString(c.Query("cache"))
|
return c.SendString(c.Query("cache"))
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "123", string(body))
|
utils.AssertEqual(t, "123", string(body))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err = io.ReadAll(resp.Body)
|
body, err = io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "12345", string(body))
|
utils.AssertEqual(t, "12345", string(body))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err = io.ReadAll(resp.Body)
|
body, err = io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "123", string(body))
|
utils.AssertEqual(t, "123", string(body))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err = io.ReadAll(resp.Body)
|
body, err = io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -384,25 +384,25 @@ func Test_Cache_Post(t *testing.T) {
|
||||||
return c.SendString(c.Query("cache"))
|
return c.SendString(c.Query("cache"))
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "123", string(body))
|
utils.AssertEqual(t, "123", string(body))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err = io.ReadAll(resp.Body)
|
body, err = io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "123", string(body))
|
utils.AssertEqual(t, "123", string(body))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err = io.ReadAll(resp.Body)
|
body, err = io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "123", string(body))
|
utils.AssertEqual(t, "123", string(body))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err = io.ReadAll(resp.Body)
|
body, err = io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -420,14 +420,14 @@ func Test_Cache_NothingToCache(t *testing.T) {
|
||||||
return c.SendString(time.Now().String())
|
return c.SendString(time.Now().String())
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
bodyCached, err := io.ReadAll(respCached.Body)
|
bodyCached, err := io.ReadAll(respCached.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -457,22 +457,22 @@ func Test_Cache_CustomNext(t *testing.T) {
|
||||||
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
respCached, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
bodyCached, err := io.ReadAll(respCached.Body)
|
bodyCached, err := io.ReadAll(respCached.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, true, bytes.Equal(body, bodyCached))
|
utils.AssertEqual(t, true, bytes.Equal(body, bodyCached))
|
||||||
utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "")
|
utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "")
|
||||||
|
|
||||||
_, err = app.Test(httptest.NewRequest("GET", "/error", nil))
|
_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil))
|
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "")
|
utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "")
|
||||||
}
|
}
|
||||||
|
@ -491,7 +491,7 @@ func Test_CustomKey(t *testing.T) {
|
||||||
return c.SendString("hi")
|
return c.SendString("hi")
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
_, err := app.Test(req)
|
_, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, true, called)
|
utils.AssertEqual(t, true, called)
|
||||||
|
@ -505,7 +505,9 @@ func Test_CustomExpiration(t *testing.T) {
|
||||||
var newCacheTime int
|
var newCacheTime int
|
||||||
app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration {
|
app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration {
|
||||||
called = true
|
called = true
|
||||||
newCacheTime, _ = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
|
var err error
|
||||||
|
newCacheTime, err = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
return time.Second * time.Duration(newCacheTime)
|
return time.Second * time.Duration(newCacheTime)
|
||||||
}}))
|
}}))
|
||||||
|
|
||||||
|
@ -515,7 +517,7 @@ func Test_CustomExpiration(t *testing.T) {
|
||||||
return c.SendString(now)
|
return c.SendString(now)
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, true, called)
|
utils.AssertEqual(t, true, called)
|
||||||
utils.AssertEqual(t, 1, newCacheTime)
|
utils.AssertEqual(t, 1, newCacheTime)
|
||||||
|
@ -523,7 +525,7 @@ func Test_CustomExpiration(t *testing.T) {
|
||||||
// Sleep until the cache is expired
|
// Sleep until the cache is expired
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
cachedResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
@ -536,7 +538,7 @@ func Test_CustomExpiration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next response should be cached
|
// Next response should be cached
|
||||||
cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
cachedRespNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
|
cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -559,12 +561,12 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) {
|
||||||
return c.SendString("hi")
|
return c.SendString("hi")
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||||
|
|
||||||
req = httptest.NewRequest("GET", "/", nil)
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
resp, err = app.Test(req)
|
resp, err = app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||||
|
@ -594,19 +596,19 @@ func Test_CacheHeader(t *testing.T) {
|
||||||
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache"))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache"))
|
||||||
|
|
||||||
errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil))
|
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
|
||||||
}
|
}
|
||||||
|
@ -622,12 +624,12 @@ func Test_Cache_WithHead(t *testing.T) {
|
||||||
return c.SendString(now)
|
return c.SendString(now)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("HEAD", "/", nil)
|
req := httptest.NewRequest(fiber.MethodHead, "/", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||||
|
|
||||||
cachedReq := httptest.NewRequest("HEAD", "/", nil)
|
cachedReq := httptest.NewRequest(fiber.MethodHead, "/", nil)
|
||||||
cachedResp, err := app.Test(cachedReq)
|
cachedResp, err := app.Test(cachedReq)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
||||||
|
@ -649,28 +651,28 @@ func Test_Cache_WithHeadThenGet(t *testing.T) {
|
||||||
return c.SendString(c.Query("cache"))
|
return c.SendString(c.Query("cache"))
|
||||||
})
|
})
|
||||||
|
|
||||||
headResp, err := app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil))
|
headResp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
headBody, err := io.ReadAll(headResp.Body)
|
headBody, err := io.ReadAll(headResp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "", string(headBody))
|
utils.AssertEqual(t, "", string(headBody))
|
||||||
utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache"))
|
||||||
|
|
||||||
headResp, err = app.Test(httptest.NewRequest("HEAD", "/?cache=123", nil))
|
headResp, err = app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
headBody, err = io.ReadAll(headResp.Body)
|
headBody, err = io.ReadAll(headResp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "", string(headBody))
|
utils.AssertEqual(t, "", string(headBody))
|
||||||
utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache"))
|
||||||
|
|
||||||
getResp, err := app.Test(httptest.NewRequest("GET", "/?cache=123", nil))
|
getResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
getBody, err := io.ReadAll(getResp.Body)
|
getBody, err := io.ReadAll(getResp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "123", string(getBody))
|
utils.AssertEqual(t, "123", string(getBody))
|
||||||
utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache"))
|
utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache"))
|
||||||
|
|
||||||
getResp, err = app.Test(httptest.NewRequest("GET", "/?cache=123", nil))
|
getResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
getBody, err = io.ReadAll(getResp.Body)
|
getBody, err = io.ReadAll(getResp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -691,7 +693,7 @@ func Test_CustomCacheHeader(t *testing.T) {
|
||||||
return c.SendString("Hello, World!")
|
return c.SendString("Hello, World!")
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status"))
|
utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status"))
|
||||||
}
|
}
|
||||||
|
@ -702,7 +704,7 @@ func Test_CustomCacheHeader(t *testing.T) {
|
||||||
func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
||||||
i := 0
|
i := 0
|
||||||
return func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
return func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
||||||
i += 1
|
i++
|
||||||
return time.Hour * time.Duration(i)
|
return time.Hour * time.Duration(i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -738,7 +740,7 @@ func Test_Cache_MaxBytesOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, tcase := range cases {
|
for idx, tcase := range cases {
|
||||||
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil))
|
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
||||||
}
|
}
|
||||||
|
@ -756,7 +758,8 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
|
||||||
|
|
||||||
app.Get("/*", func(c *fiber.Ctx) error {
|
app.Get("/*", func(c *fiber.Ctx) error {
|
||||||
path := c.Context().URI().LastPathSegment()
|
path := c.Context().URI().LastPathSegment()
|
||||||
size, _ := strconv.Atoi(string(path))
|
size, err := strconv.Atoi(string(path))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
return c.Send(make([]byte, size))
|
return c.Send(make([]byte, size))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -772,7 +775,7 @@ func Test_Cache_MaxBytesSizes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, tcase := range cases {
|
for idx, tcase := range cases {
|
||||||
rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil))
|
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
||||||
}
|
}
|
||||||
|
@ -785,14 +788,14 @@ func Benchmark_Cache(b *testing.B) {
|
||||||
app.Use(New())
|
app.Use(New())
|
||||||
|
|
||||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||||
data, _ := os.ReadFile("../../.github/README.md")
|
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
|
||||||
return c.Status(fiber.StatusTeapot).Send(data)
|
return c.Status(fiber.StatusTeapot).Send(data)
|
||||||
})
|
})
|
||||||
|
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/demo")
|
fctx.Request.SetRequestURI("/demo")
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -815,14 +818,14 @@ func Benchmark_Cache_Storage(b *testing.B) {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||||
data, _ := os.ReadFile("../../.github/README.md")
|
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
|
||||||
return c.Status(fiber.StatusTeapot).Send(data)
|
return c.Status(fiber.StatusTeapot).Send(data)
|
||||||
})
|
})
|
||||||
|
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/demo")
|
fctx.Request.SetRequestURI("/demo")
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -850,7 +853,7 @@ func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/demo")
|
fctx.Request.SetRequestURI("/demo")
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -882,7 +885,7 @@ func Benchmark_Cache_MaxSize(b *testing.B) {
|
||||||
|
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
@ -49,10 +49,10 @@ type Config struct {
|
||||||
// Default: an in memory store for this process only
|
// Default: an in memory store for this process only
|
||||||
Storage fiber.Storage
|
Storage fiber.Storage
|
||||||
|
|
||||||
// Deprecated, use Storage instead
|
// Deprecated: Use Storage instead
|
||||||
Store fiber.Storage
|
Store fiber.Storage
|
||||||
|
|
||||||
// Deprecated, use KeyGenerator instead
|
// Deprecated: Use KeyGenerator instead
|
||||||
Key func(*fiber.Ctx) string
|
Key func(*fiber.Ctx) string
|
||||||
|
|
||||||
// allows you to store additional headers generated by next middlewares & handler
|
// allows you to store additional headers generated by next middlewares & handler
|
||||||
|
@ -75,6 +75,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
Expiration: 1 * time.Minute,
|
Expiration: 1 * time.Minute,
|
||||||
|
@ -102,11 +104,11 @@ func configDefault(config ...Config) Config {
|
||||||
|
|
||||||
// Set default values
|
// Set default values
|
||||||
if cfg.Store != nil {
|
if cfg.Store != nil {
|
||||||
fmt.Println("[CACHE] Store is deprecated, please use Storage")
|
log.Printf("[CACHE] Store is deprecated, please use Storage\n")
|
||||||
cfg.Storage = cfg.Store
|
cfg.Storage = cfg.Store
|
||||||
}
|
}
|
||||||
if cfg.Key != nil {
|
if cfg.Key != nil {
|
||||||
fmt.Println("[CACHE] Key is deprecated, please use KeyGenerator")
|
log.Printf("[CACHE] Key is deprecated, please use KeyGenerator\n")
|
||||||
cfg.KeyGenerator = cfg.Key
|
cfg.KeyGenerator = cfg.Key
|
||||||
}
|
}
|
||||||
if cfg.Next == nil {
|
if cfg.Next == nil {
|
||||||
|
|
|
@ -41,7 +41,7 @@ func (h indexedHeap) Swap(i, j int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *indexedHeap) Push(x interface{}) {
|
func (h *indexedHeap) Push(x interface{}) {
|
||||||
h.pushInternal(x.(heapEntry))
|
h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert // Forced type assertion required to implement the heap.Interface interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *indexedHeap) Pop() interface{} {
|
func (h *indexedHeap) Pop() interface{} {
|
||||||
|
@ -65,7 +65,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
|
||||||
idx = h.entries[:n+1][n].idx
|
idx = h.entries[:n+1][n].idx
|
||||||
} else {
|
} else {
|
||||||
idx = h.maxidx
|
idx = h.maxidx
|
||||||
h.maxidx += 1
|
h.maxidx++
|
||||||
h.indices = append(h.indices, idx)
|
h.indices = append(h.indices, idx)
|
||||||
}
|
}
|
||||||
// Push manually to avoid allocation
|
// Push manually to avoid allocation
|
||||||
|
@ -77,7 +77,7 @@ func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *indexedHeap) removeInternal(realIdx int) (string, uint) {
|
func (h *indexedHeap) removeInternal(realIdx int) (string, uint) {
|
||||||
x := heap.Remove(h, realIdx).(heapEntry)
|
x := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
|
||||||
return x.key, x.bytes
|
return x.key, x.bytes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ func newManager(storage fiber.Storage) *manager {
|
||||||
|
|
||||||
// acquire returns an *entry from the sync.Pool
|
// acquire returns an *entry from the sync.Pool
|
||||||
func (m *manager) acquire() *item {
|
func (m *manager) acquire() *item {
|
||||||
return m.pool.Get().(*item)
|
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// release and reset *entry to sync.Pool
|
// release and reset *entry to sync.Pool
|
||||||
|
@ -69,38 +69,47 @@ func (m *manager) release(e *item) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// get data from storage or memory
|
// get data from storage or memory
|
||||||
func (m *manager) get(key string) (it *item) {
|
func (m *manager) get(key string) *item {
|
||||||
|
var it *item
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
it = m.acquire()
|
it = m.acquire()
|
||||||
if raw, _ := m.storage.Get(key); raw != nil {
|
raw, err := m.storage.Get(key)
|
||||||
|
if err != nil {
|
||||||
|
return it
|
||||||
|
}
|
||||||
|
if raw != nil {
|
||||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||||
return
|
return it
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return it
|
||||||
}
|
}
|
||||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||||
it = m.acquire()
|
it = m.acquire()
|
||||||
|
return it
|
||||||
}
|
}
|
||||||
return
|
return it
|
||||||
}
|
}
|
||||||
|
|
||||||
// get raw data from storage or memory
|
// get raw data from storage or memory
|
||||||
func (m *manager) getRaw(key string) (raw []byte) {
|
func (m *manager) getRaw(key string) []byte {
|
||||||
|
var raw []byte
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
raw, _ = m.storage.Get(key)
|
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Handle error here
|
||||||
} else {
|
} else {
|
||||||
raw, _ = m.memory.Get(key).([]byte)
|
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Handle error here
|
||||||
}
|
}
|
||||||
return
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
// set data to storage or memory
|
// set data to storage or memory
|
||||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||||
_ = m.storage.Set(key, raw, exp)
|
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||||
}
|
}
|
||||||
|
// we can release data because it's serialized to database
|
||||||
|
m.release(it)
|
||||||
} else {
|
} else {
|
||||||
m.memory.Set(key, it, exp)
|
m.memory.Set(key, it, exp)
|
||||||
}
|
}
|
||||||
|
@ -109,16 +118,16 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||||
// set data to storage or memory
|
// set data to storage or memory
|
||||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
_ = m.storage.Set(key, raw, exp)
|
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||||
} else {
|
} else {
|
||||||
m.memory.Set(key, raw, exp)
|
m.memory.Set(key, raw, exp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete data from storage or memory
|
// delete data from storage or memory
|
||||||
func (m *manager) delete(key string) {
|
func (m *manager) del(key string) {
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
_ = m.storage.Delete(key)
|
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Handle error here
|
||||||
} else {
|
} else {
|
||||||
m.memory.Delete(key)
|
m.memory.Delete(key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package compress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,10 @@ import (
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var filedata []byte
|
var filedata []byte
|
||||||
|
|
||||||
|
//nolint:gochecknoinits // init() is used to populate a global var from a README file
|
||||||
func init() {
|
func init() {
|
||||||
dat, err := os.ReadFile("../../.github/README.md")
|
dat, err := os.ReadFile("../../.github/README.md")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -34,7 +36,7 @@ func Test_Compress_Gzip(t *testing.T) {
|
||||||
return c.Send(filedata)
|
return c.Send(filedata)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
|
@ -64,7 +66,7 @@ func Test_Compress_Different_Level(t *testing.T) {
|
||||||
return c.Send(filedata)
|
return c.Send(filedata)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
|
@ -90,7 +92,7 @@ func Test_Compress_Deflate(t *testing.T) {
|
||||||
return c.Send(filedata)
|
return c.Send(filedata)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Header.Set("Accept-Encoding", "deflate")
|
req.Header.Set("Accept-Encoding", "deflate")
|
||||||
|
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
|
@ -114,7 +116,7 @@ func Test_Compress_Brotli(t *testing.T) {
|
||||||
return c.Send(filedata)
|
return c.Send(filedata)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Header.Set("Accept-Encoding", "br")
|
req.Header.Set("Accept-Encoding", "br")
|
||||||
|
|
||||||
resp, err := app.Test(req, 10000)
|
resp, err := app.Test(req, 10000)
|
||||||
|
@ -138,7 +140,7 @@ func Test_Compress_Disabled(t *testing.T) {
|
||||||
return c.Send(filedata)
|
return c.Send(filedata)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Header.Set("Accept-Encoding", "br")
|
req.Header.Set("Accept-Encoding", "br")
|
||||||
|
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
|
@ -162,7 +164,7 @@ func Test_Compress_Next_Error(t *testing.T) {
|
||||||
return errors.New("next error")
|
return errors.New("next error")
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
|
@ -185,7 +187,7 @@ func Test_Compress_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,8 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
Level: LevelDefault,
|
Level: LevelDefault,
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package cors
|
package cors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -54,6 +53,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
AllowOrigins: "*",
|
AllowOrigins: "*",
|
||||||
|
@ -128,7 +129,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simple request
|
// Simple request
|
||||||
if c.Method() != http.MethodOptions {
|
if c.Method() != fiber.MethodOptions {
|
||||||
c.Vary(fiber.HeaderOrigin)
|
c.Vary(fiber.HeaderOrigin)
|
||||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -237,7 +238,7 @@ func Test_CORS_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package cors
|
package cors
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
func matchScheme(domain, pattern string) bool {
|
func matchScheme(domain, pattern string) bool {
|
||||||
didx := strings.Index(domain, ":")
|
didx := strings.Index(domain, ":")
|
||||||
|
@ -20,18 +22,20 @@ func matchSubdomain(domain, pattern string) bool {
|
||||||
}
|
}
|
||||||
domAuth := domain[didx+3:]
|
domAuth := domain[didx+3:]
|
||||||
// to avoid long loop by invalid long domain
|
// to avoid long loop by invalid long domain
|
||||||
if len(domAuth) > 253 {
|
const maxDomainLen = 253
|
||||||
|
if len(domAuth) > maxDomainLen {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
patAuth := pattern[pidx+3:]
|
patAuth := pattern[pidx+3:]
|
||||||
|
|
||||||
domComp := strings.Split(domAuth, ".")
|
domComp := strings.Split(domAuth, ".")
|
||||||
patComp := strings.Split(patAuth, ".")
|
patComp := strings.Split(patAuth, ".")
|
||||||
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
const divHalf = 2
|
||||||
|
for i := len(domComp)/divHalf - 1; i >= 0; i-- {
|
||||||
opp := len(domComp) - 1 - i
|
opp := len(domComp) - 1 - i
|
||||||
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
||||||
}
|
}
|
||||||
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
for i := len(patComp)/divHalf - 1; i >= 0; i-- {
|
||||||
opp := len(patComp) - 1 - i
|
opp := len(patComp) - 1 - i
|
||||||
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package csrf
|
package csrf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"log"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -80,13 +80,13 @@ type Config struct {
|
||||||
// Optional. Default: utils.UUID
|
// Optional. Default: utils.UUID
|
||||||
KeyGenerator func() string
|
KeyGenerator func() string
|
||||||
|
|
||||||
// Deprecated, please use Expiration
|
// Deprecated: Please use Expiration
|
||||||
CookieExpires time.Duration
|
CookieExpires time.Duration
|
||||||
|
|
||||||
// Deprecated, please use Cookie* related fields
|
// Deprecated: Please use Cookie* related fields
|
||||||
Cookie *fiber.Cookie
|
Cookie *fiber.Cookie
|
||||||
|
|
||||||
// Deprecated, please use KeyLookup
|
// Deprecated: Please use KeyLookup
|
||||||
TokenLookup string
|
TokenLookup string
|
||||||
|
|
||||||
// ErrorHandler is executed when an error is returned from fiber.Handler.
|
// ErrorHandler is executed when an error is returned from fiber.Handler.
|
||||||
|
@ -105,6 +105,8 @@ type Config struct {
|
||||||
const HeaderName = "X-Csrf-Token"
|
const HeaderName = "X-Csrf-Token"
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
KeyLookup: "header:" + HeaderName,
|
KeyLookup: "header:" + HeaderName,
|
||||||
CookieName: "csrf_",
|
CookieName: "csrf_",
|
||||||
|
@ -116,7 +118,7 @@ var ConfigDefault = Config{
|
||||||
}
|
}
|
||||||
|
|
||||||
// default ErrorHandler that process return error from fiber.Handler
|
// default ErrorHandler that process return error from fiber.Handler
|
||||||
var defaultErrorHandler = func(c *fiber.Ctx, err error) error {
|
func defaultErrorHandler(_ *fiber.Ctx, _ error) error {
|
||||||
return fiber.ErrForbidden
|
return fiber.ErrForbidden
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,15 +134,15 @@ func configDefault(config ...Config) Config {
|
||||||
|
|
||||||
// Set default values
|
// Set default values
|
||||||
if cfg.TokenLookup != "" {
|
if cfg.TokenLookup != "" {
|
||||||
fmt.Println("[CSRF] TokenLookup is deprecated, please use KeyLookup")
|
log.Printf("[CSRF] TokenLookup is deprecated, please use KeyLookup\n")
|
||||||
cfg.KeyLookup = cfg.TokenLookup
|
cfg.KeyLookup = cfg.TokenLookup
|
||||||
}
|
}
|
||||||
if int(cfg.CookieExpires.Seconds()) > 0 {
|
if int(cfg.CookieExpires.Seconds()) > 0 {
|
||||||
fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration")
|
log.Printf("[CSRF] CookieExpires is deprecated, please use Expiration\n")
|
||||||
cfg.Expiration = cfg.CookieExpires
|
cfg.Expiration = cfg.CookieExpires
|
||||||
}
|
}
|
||||||
if cfg.Cookie != nil {
|
if cfg.Cookie != nil {
|
||||||
fmt.Println("[CSRF] Cookie is deprecated, please use Cookie* related fields")
|
log.Printf("[CSRF] Cookie is deprecated, please use Cookie* related fields\n")
|
||||||
if cfg.Cookie.Name != "" {
|
if cfg.Cookie.Name != "" {
|
||||||
cfg.CookieName = cfg.Cookie.Name
|
cfg.CookieName = cfg.Cookie.Name
|
||||||
}
|
}
|
||||||
|
@ -178,7 +180,8 @@ func configDefault(config ...Config) Config {
|
||||||
// Generate the correct extractor to get the token from the correct location
|
// Generate the correct extractor to get the token from the correct location
|
||||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||||
|
|
||||||
if len(selectors) != 2 {
|
const numParts = 2
|
||||||
|
if len(selectors) != numParts {
|
||||||
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
|
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var errTokenNotFound = errors.New("csrf token not found")
|
||||||
errTokenNotFound = errors.New("csrf token not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
// New creates a new middleware handler
|
// New creates a new middleware handler
|
||||||
func New(config ...Config) fiber.Handler {
|
func New(config ...Config) fiber.Handler {
|
||||||
|
@ -22,7 +20,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
dummyValue := []byte{'+'}
|
dummyValue := []byte{'+'}
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c *fiber.Ctx) (err error) {
|
return func(c *fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
|
@ -39,7 +37,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
// Assume that anything not defined as 'safe' by RFC7231 needs protection
|
// Assume that anything not defined as 'safe' by RFC7231 needs protection
|
||||||
|
|
||||||
// Extract token from client request i.e. header, query, param, form or cookie
|
// Extract token from client request i.e. header, query, param, form or cookie
|
||||||
token, err = cfg.Extractor(c)
|
token, err := cfg.Extractor(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cfg.ErrorHandler(c, err)
|
return cfg.ErrorHandler(c, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,7 +24,7 @@ func Test_CSRF(t *testing.T) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
methods := [4]string{"GET", "HEAD", "OPTIONS", "TRACE"}
|
methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
|
||||||
|
|
||||||
for _, method := range methods {
|
for _, method := range methods {
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
|
@ -33,14 +34,14 @@ func Test_CSRF(t *testing.T) {
|
||||||
// Without CSRF cookie
|
// Without CSRF cookie
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||||
|
|
||||||
// Empty/invalid CSRF token
|
// Empty/invalid CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(HeaderName, "johndoe")
|
ctx.Request.Header.Set(HeaderName, "johndoe")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||||
|
@ -55,7 +56,7 @@ func Test_CSRF(t *testing.T) {
|
||||||
|
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(HeaderName, token)
|
ctx.Request.Header.Set(HeaderName, token)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
|
@ -72,7 +73,7 @@ func Test_CSRF_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -92,7 +93,7 @@ func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
|
||||||
|
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +111,7 @@ func Test_CSRF_From_Form(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Invalid CSRF token
|
// Invalid CSRF token
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||||
|
@ -118,12 +119,12 @@ func Test_CSRF_From_Form(t *testing.T) {
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||||
|
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
||||||
ctx.Request.SetBodyString("_csrf=" + token)
|
ctx.Request.SetBodyString("_csrf=" + token)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
@ -144,7 +145,7 @@ func Test_CSRF_From_Query(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Invalid CSRF token
|
// Invalid CSRF token
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
|
ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||||
|
@ -152,7 +153,7 @@ func Test_CSRF_From_Query(t *testing.T) {
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||||
|
@ -161,7 +162,7 @@ func Test_CSRF_From_Query(t *testing.T) {
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.SetRequestURI("/?_csrf=" + token)
|
ctx.Request.SetRequestURI("/?_csrf=" + token)
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
||||||
|
@ -181,7 +182,7 @@ func Test_CSRF_From_Param(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Invalid CSRF token
|
// Invalid CSRF token
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.SetRequestURI("/" + utils.UUID())
|
ctx.Request.SetRequestURI("/" + utils.UUID())
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||||
|
@ -189,7 +190,7 @@ func Test_CSRF_From_Param(t *testing.T) {
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
ctx.Request.SetRequestURI("/" + utils.UUID())
|
ctx.Request.SetRequestURI("/" + utils.UUID())
|
||||||
h(ctx)
|
h(ctx)
|
||||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||||
|
@ -198,7 +199,7 @@ func Test_CSRF_From_Param(t *testing.T) {
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.SetRequestURI("/" + token)
|
ctx.Request.SetRequestURI("/" + token)
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
||||||
|
@ -218,7 +219,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Invalid CSRF token
|
// Invalid CSRF token
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
|
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
@ -227,7 +228,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||||
|
@ -235,7 +236,7 @@ func Test_CSRF_From_Cookie(t *testing.T) {
|
||||||
|
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
|
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
@ -268,7 +269,7 @@ func Test_CSRF_From_Custom(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Invalid CSRF token
|
// Invalid CSRF token
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||||
|
@ -276,12 +277,12 @@ func Test_CSRF_From_Custom(t *testing.T) {
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||||
|
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
||||||
ctx.Request.SetBodyString("_csrf=" + token)
|
ctx.Request.SetBodyString("_csrf=" + token)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
@ -307,13 +308,13 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
|
||||||
// invalid CSRF token
|
// invalid CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(HeaderName, "johndoe")
|
ctx.Request.Header.Set(HeaderName, "johndoe")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
||||||
|
@ -339,13 +340,13 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
|
||||||
// empty CSRF token
|
// empty CSRF token
|
||||||
ctx.Request.Reset()
|
ctx.Request.Reset()
|
||||||
ctx.Response.Reset()
|
ctx.Response.Reset()
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
||||||
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
|
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
|
||||||
|
@ -355,7 +356,7 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
|
||||||
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
|
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
|
||||||
// t.Parallel()
|
// t.Parallel()
|
||||||
// app := fiber.New()
|
// app := fiber.New()
|
||||||
//
|
|
||||||
// app.Use(New())
|
// app.Use(New())
|
||||||
// app.Get("/", func(c *fiber.Ctx) error {
|
// app.Get("/", func(c *fiber.Ctx) error {
|
||||||
// return c.SendStatus(fiber.StatusOK)
|
// return c.SendStatus(fiber.StatusOK)
|
||||||
|
@ -366,11 +367,11 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
|
||||||
// app.Post("/", func(c *fiber.Ctx) error {
|
// app.Post("/", func(c *fiber.Ctx) error {
|
||||||
// return c.SendStatus(fiber.StatusOK)
|
// return c.SendStatus(fiber.StatusOK)
|
||||||
// })
|
// })
|
||||||
//
|
|
||||||
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
// resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
// utils.AssertEqual(t, nil, err)
|
// utils.AssertEqual(t, nil, err)
|
||||||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
//
|
|
||||||
// var token string
|
// var token string
|
||||||
// for _, c := range resp.Cookies() {
|
// for _, c := range resp.Cookies() {
|
||||||
// if c.Name != ConfigDefault.CookieName {
|
// if c.Name != ConfigDefault.CookieName {
|
||||||
|
@ -379,25 +380,25 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
|
||||||
// token = c.Value
|
// token = c.Value
|
||||||
// break
|
// break
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
// fmt.Println("token", token)
|
// fmt.Println("token", token)
|
||||||
//
|
|
||||||
// getReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
// getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
// getReq.Header.Set(HeaderName, token)
|
// getReq.Header.Set(HeaderName, token)
|
||||||
// resp, err = app.Test(getReq)
|
// resp, err = app.Test(getReq)
|
||||||
//
|
|
||||||
// getReq = httptest.NewRequest(http.MethodGet, "/test", nil)
|
// getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
|
||||||
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
|
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
|
||||||
// getReq.Header.Set(HeaderName, token)
|
// getReq.Header.Set(HeaderName, token)
|
||||||
//
|
|
||||||
// resp, err = app.Test(getReq)
|
// resp, err = app.Test(getReq)
|
||||||
//
|
|
||||||
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
|
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
|
||||||
// getReq.Header.Del(HeaderName)
|
// getReq.Header.Del(HeaderName)
|
||||||
// resp, err = app.Test(getReq)
|
// resp, err = app.Test(getReq)
|
||||||
//
|
|
||||||
// postReq := httptest.NewRequest(http.MethodPost, "/", nil)
|
// postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
|
||||||
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
// postReq.Header.Set(HeaderName, token)
|
// postReq.Header.Set(HeaderName, token)
|
||||||
// resp, err = app.Test(postReq)
|
// resp, err = app.Test(postReq)
|
||||||
|
@ -417,12 +418,12 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||||
|
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
ctx.Request.Header.Set(HeaderName, token)
|
ctx.Request.Header.Set(HeaderName, token)
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -449,7 +450,7 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
|
|
|
@ -41,74 +41,23 @@ func newManager(storage fiber.Storage) *manager {
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// acquire returns an *entry from the sync.Pool
|
|
||||||
func (m *manager) acquire() *item {
|
|
||||||
return m.pool.Get().(*item)
|
|
||||||
}
|
|
||||||
|
|
||||||
// release and reset *entry to sync.Pool
|
|
||||||
func (m *manager) release(e *item) {
|
|
||||||
// don't release item if we using memory storage
|
|
||||||
if m.storage != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.pool.Put(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
// get data from storage or memory
|
|
||||||
func (m *manager) get(key string) (it *item) {
|
|
||||||
if m.storage != nil {
|
|
||||||
it = m.acquire()
|
|
||||||
if raw, _ := m.storage.Get(key); raw != nil {
|
|
||||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
|
||||||
it = m.acquire()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// get raw data from storage or memory
|
// get raw data from storage or memory
|
||||||
func (m *manager) getRaw(key string) (raw []byte) {
|
func (m *manager) getRaw(key string) []byte {
|
||||||
|
var raw []byte
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
raw, _ = m.storage.Get(key)
|
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error
|
||||||
} else {
|
} else {
|
||||||
raw, _ = m.memory.Get(key).([]byte)
|
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// set data to storage or memory
|
|
||||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
|
||||||
if m.storage != nil {
|
|
||||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
|
||||||
_ = m.storage.Set(key, raw, exp)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
|
||||||
m.memory.Set(utils.CopyString(key), it, exp)
|
|
||||||
}
|
}
|
||||||
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
// set data to storage or memory
|
// set data to storage or memory
|
||||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
_ = m.storage.Set(key, raw, exp)
|
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error
|
||||||
} else {
|
} else {
|
||||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
||||||
m.memory.Set(utils.CopyString(key), raw, exp)
|
m.memory.Set(utils.CopyString(key), raw, exp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete data from storage or memory
|
|
||||||
func (m *manager) delete(key string) {
|
|
||||||
if m.storage != nil {
|
|
||||||
_ = m.storage.Delete(key)
|
|
||||||
} else {
|
|
||||||
m.memory.Delete(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package encryptcookie
|
package encryptcookie
|
||||||
|
|
||||||
import "github.com/gofiber/fiber/v2"
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
// Config defines the config for middleware.
|
// Config defines the config for middleware.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -32,6 +34,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
Except: []string{"csrf_"},
|
Except: []string{"csrf_"},
|
||||||
|
|
|
@ -2,6 +2,7 @@ package encryptcookie
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,11 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var testKey = GenerateKey()
|
var testKey = GenerateKey()
|
||||||
|
|
||||||
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
||||||
|
@ -35,14 +37,14 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
||||||
|
|
||||||
// Test empty cookie
|
// Test empty cookie
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||||
|
|
||||||
// Test invalid cookie
|
// Test invalid cookie
|
||||||
ctx = &fasthttp.RequestCtx{}
|
ctx = &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
ctx.Request.Header.SetCookie("test", "Invalid")
|
ctx.Request.Header.SetCookie("test", "Invalid")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
|
@ -54,18 +56,19 @@ func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
||||||
|
|
||||||
// Test valid cookie
|
// Test valid cookie
|
||||||
ctx = &fasthttp.RequestCtx{}
|
ctx = &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
|
|
||||||
encryptedCookie := fasthttp.Cookie{}
|
encryptedCookie := fasthttp.Cookie{}
|
||||||
encryptedCookie.SetKey("test")
|
encryptedCookie.SetKey("test")
|
||||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||||
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||||
|
|
||||||
ctx = &fasthttp.RequestCtx{}
|
ctx = &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
|
@ -91,7 +94,7 @@ func Test_Encrypt_Cookie_Next(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
|
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
|
||||||
}
|
}
|
||||||
|
@ -123,7 +126,7 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
|
|
||||||
|
@ -135,7 +138,8 @@ func Test_Encrypt_Cookie_Except(t *testing.T) {
|
||||||
encryptedCookie := fasthttp.Cookie{}
|
encryptedCookie := fasthttp.Cookie{}
|
||||||
encryptedCookie.SetKey("test2")
|
encryptedCookie.SetKey("test2")
|
||||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||||
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,18 +173,19 @@ func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("POST")
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
|
|
||||||
encryptedCookie := fasthttp.Cookie{}
|
encryptedCookie := fasthttp.Cookie{}
|
||||||
encryptedCookie.SetKey("test")
|
encryptedCookie.SetKey("test")
|
||||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||||
decodedBytes, _ := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
|
decodedBytes, err := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "SomeThing", string(decodedBytes))
|
utils.AssertEqual(t, "SomeThing", string(decodedBytes))
|
||||||
|
|
||||||
ctx = &fasthttp.RequestCtx{}
|
ctx = &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||||
h(ctx)
|
h(ctx)
|
||||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||||
|
|
|
@ -6,47 +6,56 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EncryptCookie Encrypts a cookie value with specific encryption key
|
// EncryptCookie Encrypts a cookie value with specific encryption key
|
||||||
func EncryptCookie(value, key string) (string, error) {
|
func EncryptCookie(value, key string) (string, error) {
|
||||||
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
|
keyDecoded, err := base64.StdEncoding.DecodeString(key)
|
||||||
plaintext := []byte(value)
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to base64-decode key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(keyDecoded)
|
block, err := aes.NewCipher(keyDecoded)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to create AES cipher: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
gcm, err := cipher.NewGCM(block)
|
gcm, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to create GCM mode: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := make([]byte, gcm.NonceSize())
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to read: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
|
ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil)
|
||||||
|
|
||||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecryptCookie Decrypts a cookie value with specific encryption key
|
// DecryptCookie Decrypts a cookie value with specific encryption key
|
||||||
func DecryptCookie(value, key string) (string, error) {
|
func DecryptCookie(value, key string) (string, error) {
|
||||||
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
|
keyDecoded, err := base64.StdEncoding.DecodeString(key)
|
||||||
enc, _ := base64.StdEncoding.DecodeString(value)
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to base64-decode key: %w", err)
|
||||||
|
}
|
||||||
|
enc, err := base64.StdEncoding.DecodeString(value)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to base64-decode value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(keyDecoded)
|
block, err := aes.NewCipher(keyDecoded)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to create AES cipher: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
gcm, err := cipher.NewGCM(block)
|
gcm, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to create GCM mode: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nonceSize := gcm.NonceSize()
|
nonceSize := gcm.NonceSize()
|
||||||
|
@ -59,7 +68,7 @@ func DecryptCookie(value, key string) (string, error) {
|
||||||
|
|
||||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("failed to decrypt ciphertext: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(plaintext), nil
|
return string(plaintext), nil
|
||||||
|
@ -67,7 +76,8 @@ func DecryptCookie(value, key string) (string, error) {
|
||||||
|
|
||||||
// GenerateKey Generates an encryption key
|
// GenerateKey Generates an encryption key
|
||||||
func GenerateKey() string {
|
func GenerateKey() string {
|
||||||
ret := make([]byte, 32)
|
const keyLen = 32
|
||||||
|
ret := make([]byte, keyLen)
|
||||||
|
|
||||||
if _, err := rand.Read(ret); err != nil {
|
if _, err := rand.Read(ret); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
@ -23,10 +23,8 @@ func (envVar *EnvVar) set(key, val string) {
|
||||||
envVar.Vars[key] = val
|
envVar.Vars[key] = val
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultConfig = Config{}
|
|
||||||
|
|
||||||
func New(config ...Config) fiber.Handler {
|
func New(config ...Config) fiber.Handler {
|
||||||
var cfg = defaultConfig
|
var cfg Config
|
||||||
if len(config) > 0 {
|
if len(config) > 0 {
|
||||||
cfg = config[0]
|
cfg = config[0]
|
||||||
}
|
}
|
||||||
|
@ -57,8 +55,9 @@ func newEnvVar(cfg Config) *EnvVar {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
const numElems = 2
|
||||||
for _, envVal := range os.Environ() {
|
for _, envVal := range os.Environ() {
|
||||||
keyVal := strings.SplitN(envVal, "=", 2)
|
keyVal := strings.SplitN(envVal, "=", numElems)
|
||||||
if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists {
|
if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists {
|
||||||
vars.set(keyVal[0], keyVal[1])
|
vars.set(keyVal[0], keyVal[1])
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
|
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||||
package envvar
|
package envvar
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -12,16 +14,25 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
|
func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
|
||||||
os.Setenv("testKey", "testEnvValue")
|
err := os.Setenv("testKey", "testEnvValue")
|
||||||
os.Setenv("anotherEnvKey", "anotherEnvVal")
|
utils.AssertEqual(t, nil, err)
|
||||||
os.Setenv("excludeKey", "excludeEnvValue")
|
err = os.Setenv("anotherEnvKey", "anotherEnvVal")
|
||||||
defer os.Unsetenv("testKey")
|
utils.AssertEqual(t, nil, err)
|
||||||
defer os.Unsetenv("anotherEnvKey")
|
err = os.Setenv("excludeKey", "excludeEnvValue")
|
||||||
defer os.Unsetenv("excludeKey")
|
utils.AssertEqual(t, nil, err)
|
||||||
|
defer func() {
|
||||||
|
err := os.Unsetenv("testKey")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
err = os.Unsetenv("anotherEnvKey")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
err = os.Unsetenv("excludeKey")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
|
|
||||||
vars := newEnvVar(Config{
|
vars := newEnvVar(Config{
|
||||||
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
|
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
|
||||||
ExcludeVars: map[string]string{"excludeKey": ""}})
|
ExcludeVars: map[string]string{"excludeKey": ""},
|
||||||
|
})
|
||||||
|
|
||||||
utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue")
|
utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue")
|
||||||
utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal")
|
utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal")
|
||||||
|
@ -30,21 +41,28 @@ func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEnvVarHandler(t *testing.T) {
|
func TestEnvVarHandler(t *testing.T) {
|
||||||
os.Setenv("testKey", "testVal")
|
err := os.Setenv("testKey", "testVal")
|
||||||
defer os.Unsetenv("testKey")
|
utils.AssertEqual(t, nil, err)
|
||||||
|
defer func() {
|
||||||
|
err := os.Unsetenv("testKey")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
|
|
||||||
expectedEnvVarResponse, _ := json.Marshal(
|
expectedEnvVarResponse, err := json.Marshal(
|
||||||
struct {
|
struct {
|
||||||
Vars map[string]string `json:"vars"`
|
Vars map[string]string `json:"vars"`
|
||||||
}{
|
}{
|
||||||
map[string]string{"testKey": "testVal"},
|
map[string]string{"testKey": "testVal"},
|
||||||
})
|
})
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use("/envvars", New(Config{
|
app.Use("/envvars", New(Config{
|
||||||
ExportVars: map[string]string{"testKey": ""}}))
|
ExportVars: map[string]string{"testKey": ""},
|
||||||
|
}))
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -57,14 +75,16 @@ func TestEnvVarHandler(t *testing.T) {
|
||||||
func TestEnvVarHandlerNotMatched(t *testing.T) {
|
func TestEnvVarHandlerNotMatched(t *testing.T) {
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use("/envvars", New(Config{
|
app.Use("/envvars", New(Config{
|
||||||
ExportVars: map[string]string{"testKey": ""}}))
|
ExportVars: map[string]string{"testKey": ""},
|
||||||
|
}))
|
||||||
|
|
||||||
app.Get("/another-path", func(ctx *fiber.Ctx) error {
|
app.Get("/another-path", func(ctx *fiber.Ctx) error {
|
||||||
utils.AssertEqual(t, nil, ctx.SendString("OK"))
|
utils.AssertEqual(t, nil, ctx.SendString("OK"))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "http://localhost/another-path", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -75,13 +95,18 @@ func TestEnvVarHandlerNotMatched(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEnvVarHandlerDefaultConfig(t *testing.T) {
|
func TestEnvVarHandlerDefaultConfig(t *testing.T) {
|
||||||
os.Setenv("testEnvKey", "testEnvVal")
|
err := os.Setenv("testEnvKey", "testEnvVal")
|
||||||
defer os.Unsetenv("testEnvKey")
|
utils.AssertEqual(t, nil, err)
|
||||||
|
defer func() {
|
||||||
|
err := os.Unsetenv("testEnvKey")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use("/envvars", New())
|
app.Use("/envvars", New())
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -98,7 +123,8 @@ func TestEnvVarHandlerMethod(t *testing.T) {
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use("/envvars", New())
|
app.Use("/envvars", New())
|
||||||
|
|
||||||
req, _ := http.NewRequest("POST", "http://localhost/envvars", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
|
||||||
|
@ -107,14 +133,19 @@ func TestEnvVarHandlerMethod(t *testing.T) {
|
||||||
func TestEnvVarHandlerSpecialValue(t *testing.T) {
|
func TestEnvVarHandlerSpecialValue(t *testing.T) {
|
||||||
testEnvKey := "testEnvKey"
|
testEnvKey := "testEnvKey"
|
||||||
fakeBase64 := "testBase64:TQ=="
|
fakeBase64 := "testBase64:TQ=="
|
||||||
os.Setenv(testEnvKey, fakeBase64)
|
err := os.Setenv(testEnvKey, fakeBase64)
|
||||||
defer os.Unsetenv(testEnvKey)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
defer func() {
|
||||||
|
err := os.Unsetenv(testEnvKey)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use("/envvars", New())
|
app.Use("/envvars", New())
|
||||||
app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
|
app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "http://localhost/envvars", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
@ -126,7 +157,8 @@ func TestEnvVarHandlerSpecialValue(t *testing.T) {
|
||||||
val := envVars.Vars[testEnvKey]
|
val := envVars.Vars[testEnvKey]
|
||||||
utils.AssertEqual(t, fakeBase64, val)
|
utils.AssertEqual(t, fakeBase64, val)
|
||||||
|
|
||||||
req, _ = http.NewRequest("GET", "http://localhost/envvars/export", nil)
|
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err = app.Test(req)
|
resp, err = app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Weak: false,
|
Weak: false,
|
||||||
Next: nil,
|
Next: nil,
|
||||||
|
|
|
@ -5,12 +5,8 @@ import (
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/valyala/bytebufferpool"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
"github.com/valyala/bytebufferpool"
|
||||||
normalizedHeaderETag = []byte("Etag")
|
|
||||||
weakPrefix = []byte("W/")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a new middleware handler
|
// New creates a new middleware handler
|
||||||
|
@ -18,32 +14,38 @@ func New(config ...Config) fiber.Handler {
|
||||||
// Set default config
|
// Set default config
|
||||||
cfg := configDefault(config...)
|
cfg := configDefault(config...)
|
||||||
|
|
||||||
crc32q := crc32.MakeTable(0xD5828281)
|
var (
|
||||||
|
normalizedHeaderETag = []byte("Etag")
|
||||||
|
weakPrefix = []byte("W/")
|
||||||
|
)
|
||||||
|
|
||||||
|
const crcPol = 0xD5828281
|
||||||
|
crc32q := crc32.MakeTable(crcPol)
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c *fiber.Ctx) (err error) {
|
return func(c *fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return err if next handler returns one
|
// Return err if next handler returns one
|
||||||
if err = c.Next(); err != nil {
|
if err := c.Next(); err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't generate ETags for invalid responses
|
// Don't generate ETags for invalid responses
|
||||||
if c.Response().StatusCode() != fiber.StatusOK {
|
if c.Response().StatusCode() != fiber.StatusOK {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
body := c.Response().Body()
|
body := c.Response().Body()
|
||||||
// Skips ETag if no response body is present
|
// Skips ETag if no response body is present
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
// Skip ETag if header is already present
|
// Skip ETag if header is already present
|
||||||
if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
|
if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate ETag for response
|
// Generate ETag for response
|
||||||
|
@ -52,14 +54,14 @@ func New(config ...Config) fiber.Handler {
|
||||||
|
|
||||||
// Enable weak tag
|
// Enable weak tag
|
||||||
if cfg.Weak {
|
if cfg.Weak {
|
||||||
_, _ = bb.Write(weakPrefix)
|
_, _ = bb.Write(weakPrefix) //nolint:errcheck // This will never fail
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = bb.WriteByte('"')
|
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
|
||||||
bb.B = appendUint(bb.Bytes(), uint32(len(body)))
|
bb.B = appendUint(bb.Bytes(), uint32(len(body)))
|
||||||
_ = bb.WriteByte('-')
|
_ = bb.WriteByte('-') //nolint:errcheck // This will never fail
|
||||||
bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q))
|
bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q))
|
||||||
_ = bb.WriteByte('"')
|
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
|
||||||
|
|
||||||
etag := bb.Bytes()
|
etag := bb.Bytes()
|
||||||
|
|
||||||
|
@ -78,7 +80,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
// W/1 != W/2 || W/1 != 2
|
// W/1 != W/2 || W/1 != 2
|
||||||
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
||||||
|
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if bytes.Contains(clientEtag, etag) {
|
if bytes.Contains(clientEtag, etag) {
|
||||||
|
@ -90,7 +92,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
// 1 != 2
|
// 1 != 2
|
||||||
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
||||||
|
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +104,7 @@ func appendUint(dst []byte, n uint32) []byte {
|
||||||
var q uint32
|
var q uint32
|
||||||
for n >= 10 {
|
for n >= 10 {
|
||||||
i--
|
i--
|
||||||
q = n / 10
|
q = n / 10 //nolint:gomnd // TODO: Explain why we divide by 10 here
|
||||||
buf[i] = '0' + byte(n-q*10)
|
buf[i] = '0' + byte(n-q*10)
|
||||||
n = q
|
n = q
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,7 +22,7 @@ func Test_ETag_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -37,7 +38,7 @@ func Test_ETag_SkipError(t *testing.T) {
|
||||||
return fiber.ErrForbidden
|
return fiber.ErrForbidden
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -53,7 +54,7 @@ func Test_ETag_NotStatusOK(t *testing.T) {
|
||||||
return c.SendStatus(fiber.StatusCreated)
|
return c.SendStatus(fiber.StatusCreated)
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -69,7 +70,7 @@ func Test_ETag_NoBody(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -91,7 +92,7 @@ func Test_ETag_NewEtag(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
|
@ -102,7 +103,7 @@ func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
||||||
return c.SendString("Hello, World!")
|
return c.SendString("Hello, World!")
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
if headerIfNoneMatch {
|
if headerIfNoneMatch {
|
||||||
etag := `"non-match"`
|
etag := `"non-match"`
|
||||||
if matched {
|
if matched {
|
||||||
|
@ -145,7 +146,7 @@ func Test_ETag_WeakEtag(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
|
@ -156,7 +157,7 @@ func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
||||||
return c.SendString("Hello, World!")
|
return c.SendString("Hello, World!")
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
if headerIfNoneMatch {
|
if headerIfNoneMatch {
|
||||||
etag := `W/"non-match"`
|
etag := `W/"non-match"`
|
||||||
if matched {
|
if matched {
|
||||||
|
@ -199,7 +200,7 @@ func Test_ETag_CustomEtag(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
|
@ -214,7 +215,7 @@ func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) {
|
||||||
return c.SendString("Hello, World!")
|
return c.SendString("Hello, World!")
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
if headerIfNoneMatch {
|
if headerIfNoneMatch {
|
||||||
etag := `"non-match"`
|
etag := `"non-match"`
|
||||||
if matched {
|
if matched {
|
||||||
|
@ -255,7 +256,7 @@ func Test_ETag_CustomEtagPut(t *testing.T) {
|
||||||
return c.SendString("Hello, World!")
|
return c.SendString("Hello, World!")
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("PUT", "/", nil)
|
req := httptest.NewRequest(fiber.MethodPut, "/", nil)
|
||||||
req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)
|
req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -275,7 +276,7 @@ func Benchmark_Etag(b *testing.B) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/")
|
fctx.Request.SetRequestURI("/")
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package expvar
|
package expvar
|
||||||
|
|
||||||
import "github.com/gofiber/fiber/v2"
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
// Config defines the config for middleware.
|
// Config defines the config for middleware.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -10,6 +12,7 @@ type Config struct {
|
||||||
Next func(c *fiber.Ctx) bool
|
Next func(c *fiber.Ctx) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp/expvarhandler"
|
"github.com/valyala/fasthttp/expvarhandler"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,6 +30,6 @@ func New(config ...Config) fiber.Handler {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.Redirect("/debug/vars", 302)
|
return c.Redirect("/debug/vars", fiber.StatusFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
File: "",
|
File: "",
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
|
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||||
package favicon
|
package favicon
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// go test -run Test_Middleware_Favicon
|
// go test -run Test_Middleware_Favicon
|
||||||
|
@ -25,22 +27,22 @@ func Test_Middleware_Favicon(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Skip Favicon middleware
|
// Skip Favicon middleware
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("OPTIONS", "/favicon.ico", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodOptions, "/favicon.ico", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("PUT", "/favicon.ico", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodPut, "/favicon.ico", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
|
||||||
utils.AssertEqual(t, "GET, HEAD, OPTIONS", resp.Header.Get(fiber.HeaderAllow))
|
utils.AssertEqual(t, strings.Join([]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions}, ", "), resp.Header.Get(fiber.HeaderAllow))
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Middleware_Favicon_Not_Found
|
// go test -run Test_Middleware_Favicon_Not_Found
|
||||||
|
@ -70,8 +72,7 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||||
|
@ -83,15 +84,15 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
|
||||||
// TODO use os.Dir if fiber upgrades to 1.16
|
// TODO use os.Dir if fiber upgrades to 1.16
|
||||||
type mockFS struct{}
|
type mockFS struct{}
|
||||||
|
|
||||||
func (m mockFS) Open(name string) (http.File, error) {
|
func (mockFS) Open(name string) (http.File, error) {
|
||||||
if name == "/" {
|
if name == "/" {
|
||||||
name = "."
|
name = "."
|
||||||
} else {
|
} else {
|
||||||
name = strings.TrimPrefix(name, "/")
|
name = strings.TrimPrefix(name, "/")
|
||||||
}
|
}
|
||||||
file, err := os.Open(name)
|
file, err := os.Open(name) //nolint:gosec // We're in a test func, so this is fine
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to open: %w", err)
|
||||||
}
|
}
|
||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
@ -106,7 +107,7 @@ func Test_Middleware_Favicon_FileSystem(t *testing.T) {
|
||||||
FileSystem: mockFS{},
|
FileSystem: mockFS{},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||||
|
@ -123,7 +124,7 @@ func Test_Middleware_Favicon_CacheControl(t *testing.T) {
|
||||||
File: "../../.github/testdata/favicon.ico",
|
File: "../../.github/testdata/favicon.ico",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||||
|
@ -159,7 +160,7 @@ func Test_Favicon_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package filesystem
|
package filesystem
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -55,6 +56,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
Root: nil,
|
Root: nil,
|
||||||
|
@ -102,7 +105,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge)
|
cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge)
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c *fiber.Ctx) (err error) {
|
return func(c *fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
|
@ -131,28 +134,23 @@ func New(config ...Config) fiber.Handler {
|
||||||
path = cfg.PathPrefix + path
|
path = cfg.PathPrefix + path
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
file http.File
|
|
||||||
stat os.FileInfo
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(path) > 1 {
|
if len(path) > 1 {
|
||||||
path = utils.TrimRight(path, '/')
|
path = utils.TrimRight(path, '/')
|
||||||
}
|
}
|
||||||
file, err = cfg.Root.Open(path)
|
file, err := cfg.Root.Open(path)
|
||||||
if err != nil && os.IsNotExist(err) && cfg.NotFoundFile != "" {
|
if err != nil && os.IsNotExist(err) && cfg.NotFoundFile != "" {
|
||||||
file, err = cfg.Root.Open(cfg.NotFoundFile)
|
file, err = cfg.Root.Open(cfg.NotFoundFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return c.Status(fiber.StatusNotFound).Next()
|
return c.Status(fiber.StatusNotFound).Next()
|
||||||
}
|
}
|
||||||
return
|
return fmt.Errorf("failed to open: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if stat, err = file.Stat(); err != nil {
|
stat, err := file.Stat()
|
||||||
return
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to stat: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve index if path is directory
|
// Serve index if path is directory
|
||||||
|
@ -200,7 +198,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
c.Response().SkipBody = true
|
c.Response().SkipBody = true
|
||||||
c.Response().Header.SetContentLength(contentLength)
|
c.Response().Header.SetContentLength(contentLength)
|
||||||
if err := file.Close(); err != nil {
|
if err := file.Close(); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to close: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -210,22 +208,18 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendFile ...
|
// SendFile ...
|
||||||
func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) {
|
func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) error {
|
||||||
var (
|
file, err := fs.Open(path)
|
||||||
file http.File
|
|
||||||
stat os.FileInfo
|
|
||||||
)
|
|
||||||
|
|
||||||
file, err = fs.Open(path)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return fiber.ErrNotFound
|
return fiber.ErrNotFound
|
||||||
}
|
}
|
||||||
return err
|
return fmt.Errorf("failed to open: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if stat, err = file.Stat(); err != nil {
|
stat, err := file.Stat()
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to stat: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve index if path is directory
|
// Serve index if path is directory
|
||||||
|
@ -268,7 +262,7 @@ func SendFile(c *fiber.Ctx, fs http.FileSystem, path string) (err error) {
|
||||||
c.Response().SkipBody = true
|
c.Response().SkipBody = true
|
||||||
c.Response().Header.SetContentLength(contentLength)
|
c.Response().Header.SetContentLength(contentLength)
|
||||||
if err := file.Close(); err != nil {
|
if err := file.Close(); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to close: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
|
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||||
package filesystem
|
package filesystem
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -119,7 +121,7 @@ func Test_FileSystem(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", tt.url, nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tt.url, nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
|
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -142,7 +144,7 @@ func Test_FileSystem_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -168,7 +170,8 @@ func Test_FileSystem_Head(t *testing.T) {
|
||||||
Root: http.Dir("../../.github/testdata/fs"),
|
Root: http.Dir("../../.github/testdata/fs"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
req, _ := http.NewRequest(fiber.MethodHead, "/test", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/test", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
@ -182,7 +185,8 @@ func Test_FileSystem_NoRoot(t *testing.T) {
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use(New())
|
app.Use(New())
|
||||||
_, _ = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_FileSystem_UsingParam(t *testing.T) {
|
func Test_FileSystem_UsingParam(t *testing.T) {
|
||||||
|
@ -193,7 +197,8 @@ func Test_FileSystem_UsingParam(t *testing.T) {
|
||||||
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
||||||
})
|
})
|
||||||
|
|
||||||
req, _ := http.NewRequest(fiber.MethodHead, "/index", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/index", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
@ -207,7 +212,8 @@ func Test_FileSystem_UsingParam_NonFile(t *testing.T) {
|
||||||
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
||||||
})
|
})
|
||||||
|
|
||||||
req, _ := http.NewRequest(fiber.MethodHead, "/template", nil)
|
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/template", nil)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||||
|
|
|
@ -13,18 +13,18 @@ import (
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getFileExtension(path string) string {
|
func getFileExtension(p string) string {
|
||||||
n := strings.LastIndexByte(path, '.')
|
n := strings.LastIndexByte(p, '.')
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return path[n:]
|
return p[n:]
|
||||||
}
|
}
|
||||||
|
|
||||||
func dirList(c *fiber.Ctx, f http.File) error {
|
func dirList(c *fiber.Ctx, f http.File) error {
|
||||||
fileinfos, err := f.Readdir(-1)
|
fileinfos, err := f.Readdir(-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read dir: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fm := make(map[string]os.FileInfo, len(fileinfos))
|
fm := make(map[string]os.FileInfo, len(fileinfos))
|
||||||
|
@ -36,13 +36,13 @@ func dirList(c *fiber.Ctx, f http.File) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
basePathEscaped := html.EscapeString(c.Path())
|
basePathEscaped := html.EscapeString(c.Path())
|
||||||
fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
|
_, _ = fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
|
||||||
fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
|
_, _ = fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
|
||||||
fmt.Fprint(c, "<ul>")
|
_, _ = fmt.Fprint(c, "<ul>")
|
||||||
|
|
||||||
if len(basePathEscaped) > 1 {
|
if len(basePathEscaped) > 1 {
|
||||||
parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..")
|
parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..")
|
||||||
fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
|
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(filenames)
|
sort.Strings(filenames)
|
||||||
|
@ -55,10 +55,10 @@ func dirList(c *fiber.Ctx, f http.File) error {
|
||||||
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
|
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
|
||||||
className = "file"
|
className = "file"
|
||||||
}
|
}
|
||||||
fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
|
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
|
||||||
pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime())
|
pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime())
|
||||||
}
|
}
|
||||||
fmt.Fprint(c, "</ul></body></html>")
|
_, _ = fmt.Fprint(c, "</ul></body></html>")
|
||||||
|
|
||||||
c.Type("html")
|
c.Type("html")
|
||||||
|
|
||||||
|
|
|
@ -9,9 +9,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
|
||||||
ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config defines the config for middleware.
|
// Config defines the config for middleware.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -51,13 +49,15 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: func(c *fiber.Ctx) bool {
|
Next: func(c *fiber.Ctx) bool {
|
||||||
// Skip middleware if the request was done using a safe HTTP method
|
// Skip middleware if the request was done using a safe HTTP method
|
||||||
return fiber.IsMethodSafe(c.Method())
|
return fiber.IsMethodSafe(c.Method())
|
||||||
},
|
},
|
||||||
|
|
||||||
Lifetime: 30 * time.Minute,
|
Lifetime: 30 * time.Minute, //nolint:gomnd // No magic number, just the default config
|
||||||
|
|
||||||
KeyHeader: "X-Idempotency-Key",
|
KeyHeader: "X-Idempotency-Key",
|
||||||
KeyHeaderValidate: func(k string) error {
|
KeyHeaderValidate: func(k string) error {
|
||||||
|
@ -112,7 +112,7 @@ func configDefault(config ...Config) Config {
|
||||||
|
|
||||||
if cfg.Storage == nil {
|
if cfg.Storage == nil {
|
||||||
cfg.Storage = memory.New(memory.Config{
|
cfg.Storage = memory.New(memory.Config{
|
||||||
GCInterval: cfg.Lifetime / 2,
|
GCInterval: cfg.Lifetime / 2, //nolint:gomnd // Half the lifetime interval
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||||
package idempotency_test
|
package idempotency_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -14,6 +15,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -172,5 +174,4 @@ func Benchmark_Idempotency(b *testing.B) {
|
||||||
h(c)
|
h(c)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package limiter
|
package limiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
@ -58,19 +58,21 @@ type Config struct {
|
||||||
// Default: a new Fixed Window Rate Limiter
|
// Default: a new Fixed Window Rate Limiter
|
||||||
LimiterMiddleware LimiterHandler
|
LimiterMiddleware LimiterHandler
|
||||||
|
|
||||||
// DEPRECATED: Use Expiration instead
|
// Deprecated: Use Expiration instead
|
||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
|
|
||||||
// DEPRECATED, use Storage instead
|
// Deprecated: Use Storage instead
|
||||||
Store fiber.Storage
|
Store fiber.Storage
|
||||||
|
|
||||||
// DEPRECATED, use KeyGenerator instead
|
// Deprecated: Use KeyGenerator instead
|
||||||
Key func(*fiber.Ctx) string
|
Key func(*fiber.Ctx) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Max: 5,
|
Max: 5, //nolint:gomnd // No magic number, just the default config
|
||||||
Expiration: 1 * time.Minute,
|
Expiration: 1 * time.Minute,
|
||||||
KeyGenerator: func(c *fiber.Ctx) string {
|
KeyGenerator: func(c *fiber.Ctx) string {
|
||||||
return c.IP()
|
return c.IP()
|
||||||
|
@ -95,15 +97,15 @@ func configDefault(config ...Config) Config {
|
||||||
|
|
||||||
// Set default values
|
// Set default values
|
||||||
if int(cfg.Duration.Seconds()) > 0 {
|
if int(cfg.Duration.Seconds()) > 0 {
|
||||||
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration")
|
log.Printf("[LIMITER] Duration is deprecated, please use Expiration\n")
|
||||||
cfg.Expiration = cfg.Duration
|
cfg.Expiration = cfg.Duration
|
||||||
}
|
}
|
||||||
if cfg.Key != nil {
|
if cfg.Key != nil {
|
||||||
fmt.Println("[LIMITER] Key is deprecated, please us KeyGenerator")
|
log.Printf("[LIMITER] Key is deprecated, please us KeyGenerator\n")
|
||||||
cfg.KeyGenerator = cfg.Key
|
cfg.KeyGenerator = cfg.Key
|
||||||
}
|
}
|
||||||
if cfg.Store != nil {
|
if cfg.Store != nil {
|
||||||
fmt.Println("[LIMITER] Store is deprecated, please use Storage")
|
log.Printf("[LIMITER] Store is deprecated, please use Storage\n")
|
||||||
cfg.Storage = cfg.Store
|
cfg.Storage = cfg.Store
|
||||||
}
|
}
|
||||||
if cfg.Next == nil {
|
if cfg.Next == nil {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package limiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -11,6 +10,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
singleRequest := func(wg *sync.WaitGroup) {
|
singleRequest := func(wg *sync.WaitGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -50,13 +50,13 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -80,7 +80,7 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
singleRequest := func(wg *sync.WaitGroup) {
|
singleRequest := func(wg *sync.WaitGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -96,13 +96,13 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -120,21 +120,21 @@ func Test_Limiter_No_Skip_Choices(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||||
if c.Params("status") == "fail" {
|
if c.Params("status") == "fail" { //nolint:goconst // False positive
|
||||||
return c.SendStatus(400)
|
return c.SendStatus(400)
|
||||||
}
|
}
|
||||||
return c.SendStatus(200)
|
return c.SendStatus(200)
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -157,21 +157,21 @@ func Test_Limiter_Skip_Failed_Requests(t *testing.T) {
|
||||||
return c.SendStatus(200)
|
return c.SendStatus(200)
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -196,21 +196,21 @@ func Test_Limiter_Skip_Successful_Requests(t *testing.T) {
|
||||||
return c.SendStatus(200)
|
return c.SendStatus(200)
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -232,7 +232,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/")
|
fctx.Request.SetRequestURI("/")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -252,7 +252,7 @@ func Test_Limiter_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -271,7 +271,7 @@ func Test_Limiter_Headers(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/")
|
fctx.Request.SetRequestURI("/")
|
||||||
|
|
||||||
app.Handler()(fctx)
|
app.Handler()(fctx)
|
||||||
|
@ -301,7 +301,7 @@ func Benchmark_Limiter(b *testing.B) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/")
|
fctx.Request.SetRequestURI("/")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -327,7 +327,7 @@ func Test_Sliding_Window(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
singleRequest := func(shouldFail bool) {
|
singleRequest := func(shouldFail bool) {
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
if shouldFail {
|
if shouldFail {
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
|
@ -46,7 +46,7 @@ func newManager(storage fiber.Storage) *manager {
|
||||||
|
|
||||||
// acquire returns an *entry from the sync.Pool
|
// acquire returns an *entry from the sync.Pool
|
||||||
func (m *manager) acquire() *item {
|
func (m *manager) acquire() *item {
|
||||||
return m.pool.Get().(*item)
|
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// release and reset *entry to sync.Pool
|
// release and reset *entry to sync.Pool
|
||||||
|
@ -58,37 +58,33 @@ func (m *manager) release(e *item) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// get data from storage or memory
|
// get data from storage or memory
|
||||||
func (m *manager) get(key string) (it *item) {
|
func (m *manager) get(key string) *item {
|
||||||
|
var it *item
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
it = m.acquire()
|
it = m.acquire()
|
||||||
if raw, _ := m.storage.Get(key); raw != nil {
|
raw, err := m.storage.Get(key)
|
||||||
|
if err != nil {
|
||||||
|
return it
|
||||||
|
}
|
||||||
|
if raw != nil {
|
||||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||||
return
|
return it
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return it
|
||||||
}
|
}
|
||||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||||
it = m.acquire()
|
it = m.acquire()
|
||||||
|
return it
|
||||||
}
|
}
|
||||||
return
|
return it
|
||||||
}
|
|
||||||
|
|
||||||
// get raw data from storage or memory
|
|
||||||
func (m *manager) getRaw(key string) (raw []byte) {
|
|
||||||
if m.storage != nil {
|
|
||||||
raw, _ = m.storage.Get(key)
|
|
||||||
} else {
|
|
||||||
raw, _ = m.memory.Get(key).([]byte)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// set data to storage or memory
|
// set data to storage or memory
|
||||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||||
if m.storage != nil {
|
if m.storage != nil {
|
||||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||||
_ = m.storage.Set(key, raw, exp)
|
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||||
}
|
}
|
||||||
// we can release data because it's serialized to database
|
// we can release data because it's serialized to database
|
||||||
m.release(it)
|
m.release(it)
|
||||||
|
@ -96,21 +92,3 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||||
m.memory.Set(key, it, exp)
|
m.memory.Set(key, it, exp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// set data to storage or memory
|
|
||||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
|
||||||
if m.storage != nil {
|
|
||||||
_ = m.storage.Set(key, raw, exp)
|
|
||||||
} else {
|
|
||||||
m.memory.Set(key, raw, exp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// delete data from storage or memory
|
|
||||||
func (m *manager) delete(key string) {
|
|
||||||
if m.storage != nil {
|
|
||||||
_ = m.storage.Delete(key)
|
|
||||||
} else {
|
|
||||||
m.memory.Delete(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -189,7 +189,7 @@ const (
|
||||||
TagBytesReceived = "bytesReceived"
|
TagBytesReceived = "bytesReceived"
|
||||||
TagRoute = "route"
|
TagRoute = "route"
|
||||||
TagError = "error"
|
TagError = "error"
|
||||||
// DEPRECATED: Use TagReqHeader instead
|
// Deprecated: Use TagReqHeader instead
|
||||||
TagHeader = "header:" // request header
|
TagHeader = "header:" // request header
|
||||||
TagReqHeader = "reqHeader:" // request header
|
TagReqHeader = "reqHeader:" // request header
|
||||||
TagRespHeader = "respHeader:" // response header
|
TagRespHeader = "respHeader:" // response header
|
||||||
|
|
|
@ -79,13 +79,15 @@ type Buffer interface {
|
||||||
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
|
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
Done: nil,
|
Done: nil,
|
||||||
Format: "[${time}] ${status} - ${latency} ${method} ${path}\n",
|
Format: "[${time}] ${status} - ${latency} ${method} ${path}\n",
|
||||||
TimeFormat: "15:04:05",
|
TimeFormat: "15:04:05",
|
||||||
TimeZone: "Local",
|
TimeZone: "Local",
|
||||||
TimeInterval: 500 * time.Millisecond,
|
TimeInterval: 500 * time.Millisecond, //nolint:gomnd // No magic number, just the default config
|
||||||
Output: os.Stdout,
|
Output: os.Stdout,
|
||||||
enableColors: true,
|
enableColors: true,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,10 @@
|
||||||
package logger
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DataPool = sync.Pool{New: func() interface{} { return new(Data) }}
|
|
||||||
|
|
||||||
// Data is a struct to define some variables to use in custom logger function.
|
// Data is a struct to define some variables to use in custom logger function.
|
||||||
type Data struct {
|
type Data struct {
|
||||||
Pid string
|
Pid string
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/mattn/go-colorable"
|
"github.com/mattn/go-colorable"
|
||||||
"github.com/mattn/go-isatty"
|
"github.com/mattn/go-isatty"
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
|
@ -55,6 +56,8 @@ func New(config ...Config) fiber.Handler {
|
||||||
once sync.Once
|
once sync.Once
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
errHandler fiber.ErrorHandler
|
errHandler fiber.ErrorHandler
|
||||||
|
|
||||||
|
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
|
||||||
)
|
)
|
||||||
|
|
||||||
// If colors are enabled, check terminal compatibility
|
// If colors are enabled, check terminal compatibility
|
||||||
|
@ -75,7 +78,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c *fiber.Ctx) (err error) {
|
return func(c *fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
|
@ -101,13 +104,13 @@ func New(config ...Config) fiber.Handler {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Logger data
|
// Logger data
|
||||||
data := DataPool.Get().(*Data)
|
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||||
// no need for a reset, as long as we always override everything
|
// no need for a reset, as long as we always override everything
|
||||||
data.Pid = pid
|
data.Pid = pid
|
||||||
data.ErrPaddingStr = errPaddingStr
|
data.ErrPaddingStr = errPaddingStr
|
||||||
data.Timestamp = timestamp
|
data.Timestamp = timestamp
|
||||||
// put data back in the pool
|
// put data back in the pool
|
||||||
defer DataPool.Put(data)
|
defer dataPool.Put(data)
|
||||||
|
|
||||||
// Set latency start time
|
// Set latency start time
|
||||||
if cfg.enableLatency {
|
if cfg.enableLatency {
|
||||||
|
@ -121,7 +124,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
// Manually call error handler
|
// Manually call error handler
|
||||||
if chainErr != nil {
|
if chainErr != nil {
|
||||||
if err := errHandler(c, chainErr); err != nil {
|
if err := errHandler(c, chainErr); err != nil {
|
||||||
_ = c.SendStatus(fiber.StatusInternalServerError)
|
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,7 +145,8 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format log to buffer
|
// Format log to buffer
|
||||||
_, _ = buf.WriteString(fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n",
|
_, _ = buf.WriteString( //nolint:errcheck // This will never fail
|
||||||
|
fmt.Sprintf("%s |%s %3d %s| %7v | %15s |%s %-7s %s| %-"+errPaddingStr+"s %s\n",
|
||||||
timestamp.Load().(string),
|
timestamp.Load().(string),
|
||||||
statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset,
|
statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset,
|
||||||
data.Stop.Sub(data.Start).Round(time.Millisecond),
|
data.Stop.Sub(data.Start).Round(time.Millisecond),
|
||||||
|
@ -150,10 +154,11 @@ func New(config ...Config) fiber.Handler {
|
||||||
methodColor(c.Method(), colors), c.Method(), colors.Reset,
|
methodColor(c.Method(), colors), c.Method(), colors.Reset,
|
||||||
c.Path(),
|
c.Path(),
|
||||||
formatErr,
|
formatErr,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
// Write buffer to output
|
// Write buffer to output
|
||||||
_, _ = cfg.Output.Write(buf.Bytes())
|
_, _ = cfg.Output.Write(buf.Bytes()) //nolint:errcheck // This will never fail
|
||||||
|
|
||||||
if cfg.Done != nil {
|
if cfg.Done != nil {
|
||||||
cfg.Done(c, buf.Bytes())
|
cfg.Done(c, buf.Bytes())
|
||||||
|
@ -169,7 +174,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
|
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
|
||||||
for i, logFunc := range logFunChain {
|
for i, logFunc := range logFunChain {
|
||||||
if logFunc == nil {
|
if logFunc == nil {
|
||||||
_, _ = buf.Write(templateChain[i])
|
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
|
||||||
} else if templateChain[i] == nil {
|
} else if templateChain[i] == nil {
|
||||||
_, err = logFunc(buf, c, data, "")
|
_, err = logFunc(buf, c, data, "")
|
||||||
} else {
|
} else {
|
||||||
|
@ -182,7 +187,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
|
|
||||||
// Also write errors to the buffer
|
// Also write errors to the buffer
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = buf.WriteString(err.Error())
|
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
|
||||||
}
|
}
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
// Write buffer to output
|
// Write buffer to output
|
||||||
|
@ -190,7 +195,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
// Write error to output
|
// Write error to output
|
||||||
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
|
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
|
||||||
// There is something wrong with the given io.Writer
|
// There is something wrong with the given io.Writer
|
||||||
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||||
package logger
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
@ -37,7 +39,7 @@ func Test_Logger(t *testing.T) {
|
||||||
return errors.New("some random error")
|
return errors.New("some random error")
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "some random error", buf.String())
|
utils.AssertEqual(t, "some random error", buf.String())
|
||||||
|
@ -70,21 +72,21 @@ func Test_Logger_locals(t *testing.T) {
|
||||||
return c.SendStatus(fiber.StatusOK)
|
return c.SendStatus(fiber.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "johndoe", buf.String())
|
utils.AssertEqual(t, "johndoe", buf.String())
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/int", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/int", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "55", buf.String())
|
utils.AssertEqual(t, "55", buf.String())
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/empty", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "", buf.String())
|
utils.AssertEqual(t, "", buf.String())
|
||||||
|
@ -100,7 +102,7 @@ func Test_Logger_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -113,15 +115,15 @@ func Test_Logger_Done(t *testing.T) {
|
||||||
app.Use(New(Config{
|
app.Use(New(Config{
|
||||||
Done: func(c *fiber.Ctx, logString []byte) {
|
Done: func(c *fiber.Ctx, logString []byte) {
|
||||||
if c.Response().StatusCode() == fiber.StatusOK {
|
if c.Response().StatusCode() == fiber.StatusOK {
|
||||||
buf.Write(logString)
|
_, err := buf.Write(logString)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
})).Get("/logging", func(ctx *fiber.Ctx) error {
|
})).Get("/logging", func(ctx *fiber.Ctx) error {
|
||||||
return ctx.SendStatus(fiber.StatusOK)
|
return ctx.SendStatus(fiber.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/logging", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", nil))
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, true, buf.Len() > 0)
|
utils.AssertEqual(t, true, buf.Len() > 0)
|
||||||
|
@ -135,7 +137,7 @@ func Test_Logger_ErrorTimeZone(t *testing.T) {
|
||||||
TimeZone: "invalid",
|
TimeZone: "invalid",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -156,7 +158,7 @@ func Test_Logger_ErrorOutput(t *testing.T) {
|
||||||
Output: o,
|
Output: o,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -178,7 +180,7 @@ func Test_Logger_All(t *testing.T) {
|
||||||
// Alias colors
|
// Alias colors
|
||||||
colors := app.Config().ColorScheme
|
colors := app.Config().ColorScheme
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -198,7 +200,7 @@ func Test_Query_Params(t *testing.T) {
|
||||||
Output: buf,
|
Output: buf,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/?foo=bar&baz=moz", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -226,7 +228,7 @@ func Test_Response_Body(t *testing.T) {
|
||||||
return c.Send([]byte("Post in test"))
|
return c.Send([]byte("Post in test"))
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
expectedGetResponse := "Sample response body"
|
expectedGetResponse := "Sample response body"
|
||||||
|
@ -234,7 +236,7 @@ func Test_Response_Body(t *testing.T) {
|
||||||
|
|
||||||
buf.Reset() // Reset buffer to test POST
|
buf.Reset() // Reset buffer to test POST
|
||||||
|
|
||||||
_, err = app.Test(httptest.NewRequest("POST", "/test", nil))
|
_, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
expectedPostResponse := "Post in test"
|
expectedPostResponse := "Post in test"
|
||||||
|
@ -258,7 +260,7 @@ func Test_Logger_AppendUint(t *testing.T) {
|
||||||
return c.SendString("hello")
|
return c.SendString("hello")
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "0 5 200", buf.String())
|
utils.AssertEqual(t, "0 5 200", buf.String())
|
||||||
|
@ -285,12 +287,11 @@ func Test_Logger_Data_Race(t *testing.T) {
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
resp1, err1 = app.Test(httptest.NewRequest("GET", "/", nil))
|
resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
resp2, err2 = app.Test(httptest.NewRequest("GET", "/", nil))
|
resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err1)
|
utils.AssertEqual(t, nil, err1)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode)
|
||||||
utils.AssertEqual(t, nil, err2)
|
utils.AssertEqual(t, nil, err2)
|
||||||
|
@ -299,21 +300,23 @@ func Test_Logger_Data_Race(t *testing.T) {
|
||||||
|
|
||||||
// go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4
|
// go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4
|
||||||
func Benchmark_Logger(b *testing.B) {
|
func Benchmark_Logger(b *testing.B) {
|
||||||
benchSetup := func(bb *testing.B, app *fiber.App) {
|
benchSetup := func(b *testing.B, app *fiber.App) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/")
|
fctx.Request.SetRequestURI("/")
|
||||||
|
|
||||||
bb.ReportAllocs()
|
b.ReportAllocs()
|
||||||
bb.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < bb.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
h(fctx)
|
h(fctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.AssertEqual(bb, 200, fctx.Response.Header.StatusCode())
|
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
b.Run("Base", func(bb *testing.B) {
|
b.Run("Base", func(bb *testing.B) {
|
||||||
|
@ -375,8 +378,7 @@ func Test_Response_Header(t *testing.T) {
|
||||||
return c.SendString("Hello fiber!")
|
return c.SendString("Hello fiber!")
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||||
|
@ -396,10 +398,10 @@ func Test_Req_Header(t *testing.T) {
|
||||||
app.Get("/", func(c *fiber.Ctx) error {
|
app.Get("/", func(c *fiber.Ctx) error {
|
||||||
return c.SendString("Hello fiber!")
|
return c.SendString("Hello fiber!")
|
||||||
})
|
})
|
||||||
headerReq := httptest.NewRequest("GET", "/", nil)
|
headerReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
headerReq.Header.Add("test", "Hello fiber!")
|
headerReq.Header.Add("test", "Hello fiber!")
|
||||||
resp, err := app.Test(headerReq)
|
|
||||||
|
|
||||||
|
resp, err := app.Test(headerReq)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||||
|
@ -419,10 +421,10 @@ func Test_ReqHeader_Header(t *testing.T) {
|
||||||
app.Get("/", func(c *fiber.Ctx) error {
|
app.Get("/", func(c *fiber.Ctx) error {
|
||||||
return c.SendString("Hello fiber!")
|
return c.SendString("Hello fiber!")
|
||||||
})
|
})
|
||||||
reqHeaderReq := httptest.NewRequest("GET", "/", nil)
|
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
||||||
resp, err := app.Test(reqHeaderReq)
|
|
||||||
|
|
||||||
|
resp, err := app.Test(reqHeaderReq)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||||
|
@ -449,10 +451,10 @@ func Test_CustomTags(t *testing.T) {
|
||||||
app.Get("/", func(c *fiber.Ctx) error {
|
app.Get("/", func(c *fiber.Ctx) error {
|
||||||
return c.SendString("Hello fiber!")
|
return c.SendString("Hello fiber!")
|
||||||
})
|
})
|
||||||
reqHeaderReq := httptest.NewRequest("GET", "/", nil)
|
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
||||||
resp, err := app.Test(reqHeaderReq)
|
|
||||||
|
|
||||||
|
resp, err := app.Test(reqHeaderReq)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, customTag, buf.String())
|
utils.AssertEqual(t, customTag, buf.String())
|
||||||
|
@ -492,7 +494,7 @@ func Test_Logger_ByteSent_Streaming(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
utils.AssertEqual(t, "0 0 200", buf.String())
|
utils.AssertEqual(t, "0 0 200", buf.String())
|
||||||
|
|
|
@ -31,7 +31,7 @@ const (
|
||||||
TagBytesReceived = "bytesReceived"
|
TagBytesReceived = "bytesReceived"
|
||||||
TagRoute = "route"
|
TagRoute = "route"
|
||||||
TagError = "error"
|
TagError = "error"
|
||||||
// DEPRECATED: Use TagReqHeader instead
|
// Deprecated: Use TagReqHeader instead
|
||||||
TagHeader = "header:"
|
TagHeader = "header:"
|
||||||
TagReqHeader = "reqHeader:"
|
TagReqHeader = "reqHeader:"
|
||||||
TagRespHeader = "respHeader:"
|
TagRespHeader = "respHeader:"
|
||||||
|
@ -195,7 +195,7 @@ func createTagMap(cfg *Config) map[string]LogFunc {
|
||||||
return output.WriteString(fmt.Sprintf("%7v", latency))
|
return output.WriteString(fmt.Sprintf("%7v", latency))
|
||||||
},
|
},
|
||||||
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||||
return output.WriteString(data.Timestamp.Load().(string))
|
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// merge with custom tags from user
|
// merge with custom tags from user
|
||||||
|
|
|
@ -14,13 +14,16 @@ import (
|
||||||
// funcChain contains for the parts which exist the functions for the dynamic parts
|
// funcChain contains for the parts which exist the functions for the dynamic parts
|
||||||
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
|
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
|
||||||
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
|
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
|
||||||
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [][]byte, funcChain []LogFunc, err error) {
|
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
|
||||||
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
|
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
|
||||||
templateB := utils.UnsafeBytes(cfg.Format)
|
templateB := utils.UnsafeBytes(cfg.Format)
|
||||||
startTagB := utils.UnsafeBytes(startTag)
|
startTagB := utils.UnsafeBytes(startTag)
|
||||||
endTagB := utils.UnsafeBytes(endTag)
|
endTagB := utils.UnsafeBytes(endTag)
|
||||||
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
|
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
|
||||||
|
|
||||||
|
var fixParts [][]byte
|
||||||
|
var funcChain []LogFunc
|
||||||
|
|
||||||
for {
|
for {
|
||||||
currentPos := bytes.Index(templateB, startTagB)
|
currentPos := bytes.Index(templateB, startTagB)
|
||||||
if currentPos < 0 {
|
if currentPos < 0 {
|
||||||
|
@ -42,13 +45,13 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
|
||||||
// ## function block ##
|
// ## function block ##
|
||||||
// first check for tags with parameters
|
// first check for tags with parameters
|
||||||
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
|
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
|
||||||
if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]; ok {
|
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
|
||||||
|
}
|
||||||
funcChain = append(funcChain, logFunc)
|
funcChain = append(funcChain, logFunc)
|
||||||
// add param to the fixParts
|
// add param to the fixParts
|
||||||
fixParts = append(fixParts, templateB[index+1:currentPos])
|
fixParts = append(fixParts, templateB[index+1:currentPos])
|
||||||
} else {
|
|
||||||
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
|
|
||||||
}
|
|
||||||
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
|
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
|
||||||
// add functions without parameter
|
// add functions without parameter
|
||||||
funcChain = append(funcChain, logFunc)
|
funcChain = append(funcChain, logFunc)
|
||||||
|
@ -63,5 +66,5 @@ func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) (fixParts [
|
||||||
funcChain = append(funcChain, nil)
|
funcChain = append(funcChain, nil)
|
||||||
fixParts = append(fixParts, templateB)
|
fixParts = append(fixParts, templateB)
|
||||||
|
|
||||||
return
|
return fixParts, funcChain, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,36 +41,48 @@ type Config struct {
|
||||||
// ChartJsURL for specify ChartJS library path or URL . also you can use relative path
|
// ChartJsURL for specify ChartJS library path or URL . also you can use relative path
|
||||||
//
|
//
|
||||||
// Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js
|
// Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js
|
||||||
ChartJsURL string
|
ChartJSURL string
|
||||||
|
|
||||||
index string
|
index string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Title: defaultTitle,
|
Title: defaultTitle,
|
||||||
Refresh: defaultRefresh,
|
Refresh: defaultRefresh,
|
||||||
FontURL: defaultFontURL,
|
FontURL: defaultFontURL,
|
||||||
ChartJsURL: defaultChartJsURL,
|
ChartJSURL: defaultChartJSURL,
|
||||||
CustomHead: defaultCustomHead,
|
CustomHead: defaultCustomHead,
|
||||||
APIOnly: false,
|
APIOnly: false,
|
||||||
Next: nil,
|
Next: nil,
|
||||||
index: newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL,
|
index: newIndex(viewBag{
|
||||||
defaultCustomHead}),
|
defaultTitle,
|
||||||
|
defaultRefresh,
|
||||||
|
defaultFontURL,
|
||||||
|
defaultChartJSURL,
|
||||||
|
defaultCustomHead,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
func configDefault(config ...Config) Config {
|
func configDefault(config ...Config) Config {
|
||||||
// Users can change ConfigDefault.Title/Refresh which then
|
// Users can change ConfigDefault.Title/Refresh which then
|
||||||
// become incompatible with ConfigDefault.index
|
// become incompatible with ConfigDefault.index
|
||||||
if ConfigDefault.Title != defaultTitle || ConfigDefault.Refresh != defaultRefresh ||
|
if ConfigDefault.Title != defaultTitle ||
|
||||||
ConfigDefault.FontURL != defaultFontURL || ConfigDefault.ChartJsURL != defaultChartJsURL ||
|
ConfigDefault.Refresh != defaultRefresh ||
|
||||||
|
ConfigDefault.FontURL != defaultFontURL ||
|
||||||
|
ConfigDefault.ChartJSURL != defaultChartJSURL ||
|
||||||
ConfigDefault.CustomHead != defaultCustomHead {
|
ConfigDefault.CustomHead != defaultCustomHead {
|
||||||
|
|
||||||
if ConfigDefault.Refresh < minRefresh {
|
if ConfigDefault.Refresh < minRefresh {
|
||||||
ConfigDefault.Refresh = minRefresh
|
ConfigDefault.Refresh = minRefresh
|
||||||
}
|
}
|
||||||
// update default index with new default title/refresh
|
// update default index with new default title/refresh
|
||||||
ConfigDefault.index = newIndex(viewBag{ConfigDefault.Title,
|
ConfigDefault.index = newIndex(viewBag{
|
||||||
ConfigDefault.Refresh, ConfigDefault.FontURL, ConfigDefault.ChartJsURL, ConfigDefault.CustomHead})
|
ConfigDefault.Title,
|
||||||
|
ConfigDefault.Refresh,
|
||||||
|
ConfigDefault.FontURL,
|
||||||
|
ConfigDefault.ChartJSURL,
|
||||||
|
ConfigDefault.CustomHead,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return default config if nothing provided
|
// Return default config if nothing provided
|
||||||
|
@ -93,8 +105,8 @@ func configDefault(config ...Config) Config {
|
||||||
cfg.FontURL = defaultFontURL
|
cfg.FontURL = defaultFontURL
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.ChartJsURL == "" {
|
if cfg.ChartJSURL == "" {
|
||||||
cfg.ChartJsURL = defaultChartJsURL
|
cfg.ChartJSURL = defaultChartJSURL
|
||||||
}
|
}
|
||||||
if cfg.Refresh < minRefresh {
|
if cfg.Refresh < minRefresh {
|
||||||
cfg.Refresh = minRefresh
|
cfg.Refresh = minRefresh
|
||||||
|
@ -112,8 +124,8 @@ func configDefault(config ...Config) Config {
|
||||||
cfg.index = newIndex(viewBag{
|
cfg.index = newIndex(viewBag{
|
||||||
title: cfg.Title,
|
title: cfg.Title,
|
||||||
refresh: cfg.Refresh,
|
refresh: cfg.Refresh,
|
||||||
fontUrl: cfg.FontURL,
|
fontURL: cfg.FontURL,
|
||||||
chartJsUrl: cfg.ChartJsURL,
|
chartJSURL: cfg.ChartJSURL,
|
||||||
customHead: cfg.CustomHead,
|
customHead: cfg.CustomHead,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -18,11 +18,11 @@ func Test_Config_Default(t *testing.T) {
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set title", func(t *testing.T) {
|
t.Run("set title", func(t *testing.T) {
|
||||||
|
@ -35,11 +35,11 @@ func Test_Config_Default(t *testing.T) {
|
||||||
utils.AssertEqual(t, title, cfg.Title)
|
utils.AssertEqual(t, title, cfg.Title)
|
||||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set refresh less than default", func(t *testing.T) {
|
t.Run("set refresh less than default", func(t *testing.T) {
|
||||||
|
@ -51,11 +51,11 @@ func Test_Config_Default(t *testing.T) {
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, minRefresh, cfg.Refresh)
|
utils.AssertEqual(t, minRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set refresh", func(t *testing.T) {
|
t.Run("set refresh", func(t *testing.T) {
|
||||||
|
@ -68,45 +68,45 @@ func Test_Config_Default(t *testing.T) {
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, refresh, cfg.Refresh)
|
utils.AssertEqual(t, refresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set font url", func(t *testing.T) {
|
t.Run("set font url", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fontUrl := "https://example.com"
|
fontURL := "https://example.com"
|
||||||
cfg := configDefault(Config{
|
cfg := configDefault(Config{
|
||||||
FontURL: fontUrl,
|
FontURL: fontURL,
|
||||||
})
|
})
|
||||||
|
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, fontUrl, cfg.FontURL)
|
utils.AssertEqual(t, fontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontUrl, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set chart js url", func(t *testing.T) {
|
t.Run("set chart js url", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
chartUrl := "http://example.com"
|
chartURL := "http://example.com"
|
||||||
cfg := configDefault(Config{
|
cfg := configDefault(Config{
|
||||||
ChartJsURL: chartUrl,
|
ChartJSURL: chartURL,
|
||||||
})
|
})
|
||||||
|
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, chartUrl, cfg.ChartJsURL)
|
utils.AssertEqual(t, chartURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartUrl, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set custom head", func(t *testing.T) {
|
t.Run("set custom head", func(t *testing.T) {
|
||||||
|
@ -119,11 +119,11 @@ func Test_Config_Default(t *testing.T) {
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, head, cfg.CustomHead)
|
utils.AssertEqual(t, head, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, head}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, head}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set api only", func(t *testing.T) {
|
t.Run("set api only", func(t *testing.T) {
|
||||||
|
@ -135,11 +135,11 @@ func Test_Config_Default(t *testing.T) {
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, true, cfg.APIOnly)
|
utils.AssertEqual(t, true, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("set next", func(t *testing.T) {
|
t.Run("set next", func(t *testing.T) {
|
||||||
|
@ -154,10 +154,10 @@ func Test_Config_Default(t *testing.T) {
|
||||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||||
utils.AssertEqual(t, defaultChartJsURL, cfg.ChartJsURL)
|
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJSURL)
|
||||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||||
utils.AssertEqual(t, f(nil), cfg.Next(nil))
|
utils.AssertEqual(t, f(nil), cfg.Next(nil))
|
||||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJsURL, defaultCustomHead}), cfg.index)
|
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,23 +9,22 @@ import (
|
||||||
type viewBag struct {
|
type viewBag struct {
|
||||||
title string
|
title string
|
||||||
refresh time.Duration
|
refresh time.Duration
|
||||||
fontUrl string
|
fontURL string
|
||||||
chartJsUrl string
|
chartJSURL string
|
||||||
customHead string
|
customHead string
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns index with new title/refresh
|
// returns index with new title/refresh
|
||||||
func newIndex(dat viewBag) string {
|
func newIndex(dat viewBag) string {
|
||||||
|
|
||||||
timeout := dat.refresh.Milliseconds() - timeoutDiff
|
timeout := dat.refresh.Milliseconds() - timeoutDiff
|
||||||
if timeout < timeoutDiff {
|
if timeout < timeoutDiff {
|
||||||
timeout = timeoutDiff
|
timeout = timeoutDiff
|
||||||
}
|
}
|
||||||
ts := strconv.FormatInt(timeout, 10)
|
ts := strconv.FormatInt(timeout, 10)
|
||||||
replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts,
|
replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts,
|
||||||
"$FONT_URL", dat.fontUrl, "$CHART_JS_URL", dat.chartJsUrl, "$CUSTOM_HEAD", dat.customHead,
|
"$FONT_URL", dat.fontURL, "$CHART_JS_URL", dat.chartJSURL, "$CUSTOM_HEAD", dat.customHead,
|
||||||
)
|
)
|
||||||
return replacer.Replace(indexHtml)
|
return replacer.Replace(indexHTML)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -35,11 +34,11 @@ const (
|
||||||
timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff
|
timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff
|
||||||
minRefresh = timeoutDiff * time.Millisecond
|
minRefresh = timeoutDiff * time.Millisecond
|
||||||
defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap`
|
defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap`
|
||||||
defaultChartJsURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
|
defaultChartJSURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
|
||||||
defaultCustomHead = ``
|
defaultCustomHead = ``
|
||||||
|
|
||||||
// parametrized by $TITLE and $TIMEOUT
|
// parametrized by $TITLE and $TIMEOUT
|
||||||
indexHtml = `<!DOCTYPE html>
|
indexHTML = `<!DOCTYPE html>
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
|
|
|
@ -33,18 +33,20 @@ type statsOS struct {
|
||||||
Conns int `json:"conns"`
|
Conns int `json:"conns"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var (
|
var (
|
||||||
monitPidCpu atomic.Value
|
monitPIDCPU atomic.Value
|
||||||
monitPidRam atomic.Value
|
monitPIDRAM atomic.Value
|
||||||
monitPidConns atomic.Value
|
monitPIDConns atomic.Value
|
||||||
|
|
||||||
monitOsCpu atomic.Value
|
monitOSCPU atomic.Value
|
||||||
monitOsRam atomic.Value
|
monitOSRAM atomic.Value
|
||||||
monitOsTotalRam atomic.Value
|
monitOSTotalRAM atomic.Value
|
||||||
monitOsLoadAvg atomic.Value
|
monitOSLoadAvg atomic.Value
|
||||||
monitOsConns atomic.Value
|
monitOSConns atomic.Value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var (
|
var (
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
once sync.Once
|
once sync.Once
|
||||||
|
@ -58,7 +60,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
|
|
||||||
// Start routine to update statistics
|
// Start routine to update statistics
|
||||||
once.Do(func() {
|
once.Do(func() {
|
||||||
p, _ := process.NewProcess(int32(os.Getpid()))
|
p, _ := process.NewProcess(int32(os.Getpid())) //nolint:errcheck // TODO: Handle error
|
||||||
|
|
||||||
updateStatistics(p)
|
updateStatistics(p)
|
||||||
|
|
||||||
|
@ -72,6 +74,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
|
//nolint:errcheck // Ignore the type-assertion errors
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
|
@ -83,15 +86,15 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly {
|
if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly {
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
data.PID.CPU = monitPidCpu.Load().(float64)
|
data.PID.CPU, _ = monitPIDCPU.Load().(float64)
|
||||||
data.PID.RAM = monitPidRam.Load().(uint64)
|
data.PID.RAM, _ = monitPIDRAM.Load().(uint64)
|
||||||
data.PID.Conns = monitPidConns.Load().(int)
|
data.PID.Conns, _ = monitPIDConns.Load().(int)
|
||||||
|
|
||||||
data.OS.CPU = monitOsCpu.Load().(float64)
|
data.OS.CPU, _ = monitOSCPU.Load().(float64)
|
||||||
data.OS.RAM = monitOsRam.Load().(uint64)
|
data.OS.RAM, _ = monitOSRAM.Load().(uint64)
|
||||||
data.OS.TotalRAM = monitOsTotalRam.Load().(uint64)
|
data.OS.TotalRAM, _ = monitOSTotalRAM.Load().(uint64)
|
||||||
data.OS.LoadAvg = monitOsLoadAvg.Load().(float64)
|
data.OS.LoadAvg, _ = monitOSLoadAvg.Load().(float64)
|
||||||
data.OS.Conns = monitOsConns.Load().(int)
|
data.OS.Conns, _ = monitOSConns.Load().(int)
|
||||||
mutex.Unlock()
|
mutex.Unlock()
|
||||||
return c.Status(fiber.StatusOK).JSON(data)
|
return c.Status(fiber.StatusOK).JSON(data)
|
||||||
}
|
}
|
||||||
|
@ -101,29 +104,35 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateStatistics(p *process.Process) {
|
func updateStatistics(p *process.Process) {
|
||||||
pidCpu, _ := p.CPUPercent()
|
pidCPU, err := p.CPUPercent()
|
||||||
monitPidCpu.Store(pidCpu / 10)
|
if err != nil {
|
||||||
|
monitPIDCPU.Store(pidCPU / 10) //nolint:gomnd // TODO: Explain why we divide by 10 here
|
||||||
if osCpu, _ := cpu.Percent(0, false); len(osCpu) > 0 {
|
|
||||||
monitOsCpu.Store(osCpu[0])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if pidMem, _ := p.MemoryInfo(); pidMem != nil {
|
if osCPU, err := cpu.Percent(0, false); err != nil && len(osCPU) > 0 {
|
||||||
monitPidRam.Store(pidMem.RSS)
|
monitOSCPU.Store(osCPU[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if osMem, _ := mem.VirtualMemory(); osMem != nil {
|
if pidRAM, err := p.MemoryInfo(); err != nil && pidRAM != nil {
|
||||||
monitOsRam.Store(osMem.Used)
|
monitPIDRAM.Store(pidRAM.RSS)
|
||||||
monitOsTotalRam.Store(osMem.Total)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if loadAvg, _ := load.Avg(); loadAvg != nil {
|
if osRAM, err := mem.VirtualMemory(); err != nil && osRAM != nil {
|
||||||
monitOsLoadAvg.Store(loadAvg.Load1)
|
monitOSRAM.Store(osRAM.Used)
|
||||||
|
monitOSTotalRAM.Store(osRAM.Total)
|
||||||
}
|
}
|
||||||
|
|
||||||
pidConns, _ := net.ConnectionsPid("tcp", p.Pid)
|
if loadAvg, err := load.Avg(); err != nil && loadAvg != nil {
|
||||||
monitPidConns.Store(len(pidConns))
|
monitOSLoadAvg.Store(loadAvg.Load1)
|
||||||
|
}
|
||||||
osConns, _ := net.Connections("tcp")
|
|
||||||
monitOsConns.Store(len(osConns))
|
pidConns, err := net.ConnectionsPid("tcp", p.Pid)
|
||||||
|
if err != nil {
|
||||||
|
monitPIDConns.Store(len(pidConns))
|
||||||
|
}
|
||||||
|
|
||||||
|
osConns, err := net.Connections("tcp")
|
||||||
|
if err != nil {
|
||||||
|
monitOSConns.Store(len(osConns))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -61,6 +62,7 @@ func Test_Monitor_Html(t *testing.T) {
|
||||||
conf.Refresh.Milliseconds()-timeoutDiff)
|
conf.Refresh.Milliseconds()-timeoutDiff)
|
||||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Monitor_Html_CustomCodes(t *testing.T) {
|
func Test_Monitor_Html_CustomCodes(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -82,8 +84,10 @@ func Test_Monitor_Html_CustomCodes(t *testing.T) {
|
||||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||||
|
|
||||||
// custom config
|
// custom config
|
||||||
conf := Config{Title: "New " + defaultTitle, Refresh: defaultRefresh + time.Second,
|
conf := Config{
|
||||||
ChartJsURL: "https://cdnjs.com/libraries/Chart.js",
|
Title: "New " + defaultTitle,
|
||||||
|
Refresh: defaultRefresh + time.Second,
|
||||||
|
ChartJSURL: "https://cdnjs.com/libraries/Chart.js",
|
||||||
FontURL: "/public/my-font.css",
|
FontURL: "/public/my-font.css",
|
||||||
CustomHead: `<style>body{background:#fff}</style>`,
|
CustomHead: `<style>body{background:#fff}</style>`,
|
||||||
}
|
}
|
||||||
|
@ -136,7 +140,7 @@ func Benchmark_Monitor(b *testing.B) {
|
||||||
h := app.Handler()
|
h := app.Handler()
|
||||||
|
|
||||||
fctx := &fasthttp.RequestCtx{}
|
fctx := &fasthttp.RequestCtx{}
|
||||||
fctx.Request.Header.SetMethod("GET")
|
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
fctx.Request.SetRequestURI("/")
|
fctx.Request.SetRequestURI("/")
|
||||||
fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
|
fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package pprof
|
package pprof
|
||||||
|
|
||||||
import "github.com/gofiber/fiber/v2"
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
// Config defines the config for middleware.
|
// Config defines the config for middleware.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -17,6 +19,7 @@ type Config struct {
|
||||||
Prefix string
|
Prefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,9 +5,15 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp/fasthttpadaptor"
|
"github.com/valyala/fasthttp/fasthttpadaptor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// New creates a new middleware handler
|
||||||
|
func New(config ...Config) fiber.Handler {
|
||||||
|
// Set default config
|
||||||
|
cfg := configDefault(config...)
|
||||||
|
|
||||||
// Set pprof adaptors
|
// Set pprof adaptors
|
||||||
var (
|
var (
|
||||||
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
|
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
|
||||||
|
@ -23,11 +29,6 @@ var (
|
||||||
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
|
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a new middleware handler
|
|
||||||
func New(config ...Config) fiber.Handler {
|
|
||||||
// Set default config
|
|
||||||
cfg := configDefault(config...)
|
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -48,7 +49,7 @@ type Config struct {
|
||||||
WriteBufferSize int
|
WriteBufferSize int
|
||||||
|
|
||||||
// tls config for the http client.
|
// tls config for the http client.
|
||||||
TlsConfig *tls.Config
|
TlsConfig *tls.Config //nolint:stylecheck,revive // TODO: Rename to "TLSConfig" in v3
|
||||||
|
|
||||||
// Client is custom client when client config is complex.
|
// Client is custom client when client config is complex.
|
||||||
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
||||||
|
@ -57,6 +58,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
ModifyRequest: nil,
|
ModifyRequest: nil,
|
||||||
|
|
|
@ -3,19 +3,20 @@ package proxy
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New is deprecated
|
// New is deprecated
|
||||||
func New(config Config) fiber.Handler {
|
func New(config Config) fiber.Handler {
|
||||||
fmt.Println("proxy.New is deprecated, please use proxy.Balancer instead")
|
log.Printf("proxy.New is deprecated, please use proxy.Balancer instead\n")
|
||||||
return Balancer(config)
|
return Balancer(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +26,7 @@ func Balancer(config Config) fiber.Handler {
|
||||||
cfg := configDefault(config)
|
cfg := configDefault(config)
|
||||||
|
|
||||||
// Load balanced client
|
// Load balanced client
|
||||||
var lbc = &fasthttp.LBClient{}
|
lbc := &fasthttp.LBClient{}
|
||||||
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
||||||
// will not be used if the client are set.
|
// will not be used if the client are set.
|
||||||
if config.Client == nil {
|
if config.Client == nil {
|
||||||
|
@ -61,7 +62,7 @@ func Balancer(config Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c *fiber.Ctx) (err error) {
|
return func(c *fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
|
@ -76,7 +77,7 @@ func Balancer(config Config) fiber.Handler {
|
||||||
|
|
||||||
// Modify request
|
// Modify request
|
||||||
if cfg.ModifyRequest != nil {
|
if cfg.ModifyRequest != nil {
|
||||||
if err = cfg.ModifyRequest(c); err != nil {
|
if err := cfg.ModifyRequest(c); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,7 +85,7 @@ func Balancer(config Config) fiber.Handler {
|
||||||
req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
|
req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
|
||||||
|
|
||||||
// Forward request
|
// Forward request
|
||||||
if err = lbc.Do(req, res); err != nil {
|
if err := lbc.Do(req, res); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +94,7 @@ func Balancer(config Config) fiber.Handler {
|
||||||
|
|
||||||
// Modify response
|
// Modify response
|
||||||
if cfg.ModifyResponse != nil {
|
if cfg.ModifyResponse != nil {
|
||||||
if err = cfg.ModifyResponse(c); err != nil {
|
if err := cfg.ModifyResponse(c); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -103,16 +104,20 @@ func Balancer(config Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var client = &fasthttp.Client{
|
var client = &fasthttp.Client{
|
||||||
NoDefaultUserAgentHeader: true,
|
NoDefaultUserAgentHeader: true,
|
||||||
DisablePathNormalizing: true,
|
DisablePathNormalizing: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var lock sync.RWMutex
|
var lock sync.RWMutex
|
||||||
|
|
||||||
// WithTlsConfig update http client with a user specified tls.config
|
// WithTlsConfig update http client with a user specified tls.config
|
||||||
// This function should be called before Do and Forward.
|
// This function should be called before Do and Forward.
|
||||||
// Deprecated: use WithClient instead.
|
// Deprecated: use WithClient instead.
|
||||||
|
//
|
||||||
|
//nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3
|
||||||
func WithTlsConfig(tlsConfig *tls.Config) {
|
func WithTlsConfig(tlsConfig *tls.Config) {
|
||||||
client.TLSConfig = tlsConfig
|
client.TLSConfig = tlsConfig
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -13,10 +12,11 @@ import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/internal/tlstest"
|
"github.com/gofiber/fiber/v2/internal/tlstest"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) {
|
func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
target := fiber.New(fiber.Config{DisableStartupMessage: true})
|
target := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
@ -60,7 +60,7 @@ func Test_Proxy_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -69,11 +69,11 @@ func Test_Proxy_Next(t *testing.T) {
|
||||||
func Test_Proxy(t *testing.T) {
|
func Test_Proxy(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
target, addr := createProxyTestServer(
|
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
|
return c.SendStatus(fiber.StatusTeapot)
|
||||||
)
|
})
|
||||||
|
|
||||||
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
|
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ func Test_Proxy(t *testing.T) {
|
||||||
|
|
||||||
app.Use(Balancer(Config{Servers: []string{addr}}))
|
app.Use(Balancer(Config{Servers: []string{addr}}))
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Host = addr
|
req.Host = addr
|
||||||
resp, err = app.Test(req)
|
resp, err = app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -107,7 +107,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
addr := ln.Addr().String()
|
addr := ln.Addr().String()
|
||||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true}
|
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
|
||||||
|
|
||||||
// disable certificate verification in Balancer
|
// disable certificate verification in Balancer
|
||||||
app.Use(Balancer(Config{
|
app.Use(Balancer(Config{
|
||||||
|
@ -128,9 +128,9 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
|
||||||
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
|
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
_, targetAddr := createProxyTestServer(func(c *fiber.Ctx) error {
|
_, targetAddr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
return c.SendString("hello from target")
|
return c.SendString("hello from target")
|
||||||
}, t)
|
})
|
||||||
|
|
||||||
proxyServerTLSConf, _, err := tlstest.GetTLSConfigs()
|
proxyServerTLSConf, _, err := tlstest.GetTLSConfigs()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -164,13 +164,13 @@ func Test_Proxy_Forward(t *testing.T) {
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
|
|
||||||
_, addr := createProxyTestServer(
|
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
func(c *fiber.Ctx) error { return c.SendString("forwarded") }, t,
|
return c.SendString("forwarded")
|
||||||
)
|
})
|
||||||
|
|
||||||
app.Use(Forward("http://" + addr))
|
app.Use(Forward("http://" + addr))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
addr := ln.Addr().String()
|
addr := ln.Addr().String()
|
||||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true}
|
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
|
||||||
|
|
||||||
// disable certificate verification
|
// disable certificate verification
|
||||||
WithTlsConfig(clientTLSConf)
|
WithTlsConfig(clientTLSConf)
|
||||||
|
@ -217,9 +217,9 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
|
||||||
func Test_Proxy_Modify_Response(t *testing.T) {
|
func Test_Proxy_Modify_Response(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
return c.Status(500).SendString("not modified")
|
return c.Status(500).SendString("not modified")
|
||||||
}, t)
|
})
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use(Balancer(Config{
|
app.Use(Balancer(Config{
|
||||||
|
@ -230,7 +230,7 @@ func Test_Proxy_Modify_Response(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -243,10 +243,10 @@ func Test_Proxy_Modify_Response(t *testing.T) {
|
||||||
func Test_Proxy_Modify_Request(t *testing.T) {
|
func Test_Proxy_Modify_Request(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
b := c.Request().Body()
|
b := c.Request().Body()
|
||||||
return c.SendString(string(b))
|
return c.SendString(string(b))
|
||||||
}, t)
|
})
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use(Balancer(Config{
|
app.Use(Balancer(Config{
|
||||||
|
@ -257,7 +257,7 @@ func Test_Proxy_Modify_Request(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -270,10 +270,10 @@ func Test_Proxy_Modify_Request(t *testing.T) {
|
||||||
func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
return c.SendString("fiber is awesome")
|
return c.SendString("fiber is awesome")
|
||||||
}, t)
|
})
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use(Balancer(Config{
|
app.Use(Balancer(Config{
|
||||||
|
@ -281,7 +281,7 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
||||||
Timeout: 3 * time.Second,
|
Timeout: 3 * time.Second,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 5000)
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 5000)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -294,10 +294,10 @@ func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
||||||
func Test_Proxy_With_Timeout(t *testing.T) {
|
func Test_Proxy_With_Timeout(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
return c.SendString("fiber is awesome")
|
return c.SendString("fiber is awesome")
|
||||||
}, t)
|
})
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use(Balancer(Config{
|
app.Use(Balancer(Config{
|
||||||
|
@ -305,7 +305,7 @@ func Test_Proxy_With_Timeout(t *testing.T) {
|
||||||
Timeout: 100 * time.Millisecond,
|
Timeout: 100 * time.Millisecond,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil), 2000)
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -318,16 +318,16 @@ func Test_Proxy_With_Timeout(t *testing.T) {
|
||||||
func Test_Proxy_Buffer_Size_Response(t *testing.T) {
|
func Test_Proxy_Buffer_Size_Response(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
long := strings.Join(make([]string, 5000), "-")
|
long := strings.Join(make([]string, 5000), "-")
|
||||||
c.Set("Very-Long-Header", long)
|
c.Set("Very-Long-Header", long)
|
||||||
return c.SendString("ok")
|
return c.SendString("ok")
|
||||||
}, t)
|
})
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use(Balancer(Config{Servers: []string{addr}}))
|
app.Use(Balancer(Config{Servers: []string{addr}}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) {
|
||||||
ReadBufferSize: 1024 * 8,
|
ReadBufferSize: 1024 * 8,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err = app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -357,9 +357,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
|
||||||
utils.AssertEqual(t, originalURL, c.OriginalURL())
|
utils.AssertEqual(t, originalURL, c.OriginalURL())
|
||||||
return c.SendString("ok")
|
return c.SendString("ok")
|
||||||
})
|
})
|
||||||
_, err1 := app.Test(httptest.NewRequest("GET", "/test", nil))
|
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||||
// This test requires multiple requests due to zero allocation used in fiber
|
// This test requires multiple requests due to zero allocation used in fiber
|
||||||
_, err2 := app.Test(httptest.NewRequest("GET", "/test", nil))
|
_, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||||
|
|
||||||
utils.AssertEqual(t, nil, err1)
|
utils.AssertEqual(t, nil, err1)
|
||||||
utils.AssertEqual(t, nil, err2)
|
utils.AssertEqual(t, nil, err2)
|
||||||
|
@ -369,9 +369,9 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
|
||||||
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
|
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
_, addr := createProxyTestServer(func(c *fiber.Ctx) error {
|
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
return c.SendString("hello world")
|
return c.SendString("hello world")
|
||||||
}, t)
|
})
|
||||||
|
|
||||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
app.Get("/*", func(c *fiber.Ctx) error {
|
app.Get("/*", func(c *fiber.Ctx) error {
|
||||||
|
@ -386,7 +386,7 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/http://"+addr, nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
s, err := io.ReadAll(resp.Body)
|
s, err := io.ReadAll(resp.Body)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
@ -431,9 +431,8 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
|
||||||
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
|
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
|
||||||
NoDefaultUserAgentHeader: true,
|
NoDefaultUserAgentHeader: true,
|
||||||
DisablePathNormalizing: true,
|
DisablePathNormalizing: true,
|
||||||
Dial: func(addr string) (net.Conn, error) {
|
|
||||||
return fasthttp.Dial(addr)
|
Dial: fasthttp.Dial,
|
||||||
},
|
|
||||||
}))
|
}))
|
||||||
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
||||||
|
|
||||||
|
@ -447,11 +446,11 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
|
||||||
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
|
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
target, addr := createProxyTestServer(
|
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||||
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
|
return c.SendStatus(fiber.StatusTeapot)
|
||||||
)
|
})
|
||||||
|
|
||||||
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
|
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -468,7 +467,7 @@ func Test_ProxyBalancer_Custom_Client(t *testing.T) {
|
||||||
Timeout: time.Second,
|
Timeout: time.Second,
|
||||||
}}))
|
}}))
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Host = addr
|
req.Host = addr
|
||||||
resp, err = app.Test(req)
|
resp, err = app.Test(req)
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package recover
|
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
@ -23,6 +23,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
EnableStackTrace: false,
|
EnableStackTrace: false,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package recover
|
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
|
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
|
||||||
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack()))
|
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new middleware handler
|
// New creates a new middleware handler
|
||||||
|
@ -18,7 +18,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
cfg := configDefault(config...)
|
cfg := configDefault(config...)
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c *fiber.Ctx) (err error) {
|
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
|
||||||
// Don't execute middleware if Next returns true
|
// Don't execute middleware if Next returns true
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package recover
|
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -24,7 +24,7 @@ func Test_Recover(t *testing.T) {
|
||||||
panic("Hi, I'm an error!")
|
panic("Hi, I'm an error!")
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func Test_Recover_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ func Test_Recover_EnableStackTrace(t *testing.T) {
|
||||||
panic("Hi, I'm an error!")
|
panic("Hi, I'm an error!")
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,8 @@ type Config struct {
|
||||||
// It uses a fast UUID generator which will expose the number of
|
// It uses a fast UUID generator which will expose the number of
|
||||||
// requests made to the server. To conceal this value for better
|
// requests made to the server. To conceal this value for better
|
||||||
// privacy, use the "utils.UUIDv4" generator.
|
// privacy, use the "utils.UUIDv4" generator.
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Next: nil,
|
Next: nil,
|
||||||
Header: fiber.HeaderXRequestID,
|
Header: fiber.HeaderXRequestID,
|
||||||
|
|
|
@ -19,14 +19,14 @@ func Test_RequestID(t *testing.T) {
|
||||||
return c.SendString("Hello, World 👋!")
|
return c.SendString("Hello, World 👋!")
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
reqid := resp.Header.Get(fiber.HeaderXRequestID)
|
reqid := resp.Header.Get(fiber.HeaderXRequestID)
|
||||||
utils.AssertEqual(t, 36, len(reqid))
|
utils.AssertEqual(t, 36, len(reqid))
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
req.Header.Add(fiber.HeaderXRequestID, reqid)
|
req.Header.Add(fiber.HeaderXRequestID, reqid)
|
||||||
|
|
||||||
resp, err = app.Test(req)
|
resp, err = app.Test(req)
|
||||||
|
@ -45,7 +45,7 @@ func Test_RequestID_Next(t *testing.T) {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "")
|
utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "")
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
|
@ -54,13 +54,13 @@ func Test_RequestID_Next(t *testing.T) {
|
||||||
// go test -run Test_RequestID_Locals
|
// go test -run Test_RequestID_Locals
|
||||||
func Test_RequestID_Locals(t *testing.T) {
|
func Test_RequestID_Locals(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
reqId := "ThisIsARequestId"
|
reqID := "ThisIsARequestId"
|
||||||
ctxKey := "ThisIsAContextKey"
|
ctxKey := "ThisIsAContextKey"
|
||||||
|
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Use(New(Config{
|
app.Use(New(Config{
|
||||||
Generator: func() string {
|
Generator: func() string {
|
||||||
return reqId
|
return reqID
|
||||||
},
|
},
|
||||||
ContextKey: ctxKey,
|
ContextKey: ctxKey,
|
||||||
}))
|
}))
|
||||||
|
@ -68,11 +68,11 @@ func Test_RequestID_Locals(t *testing.T) {
|
||||||
var ctxVal string
|
var ctxVal string
|
||||||
|
|
||||||
app.Use(func(c *fiber.Ctx) error {
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
ctxVal = c.Locals(ctxKey).(string)
|
ctxVal = c.Locals(ctxKey).(string) //nolint:forcetypeassert,errcheck // We always store a string in here
|
||||||
return c.Next()
|
return c.Next()
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, reqId, ctxVal)
|
utils.AssertEqual(t, reqID, ctxVal)
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,7 +148,7 @@ type Config struct {
|
||||||
// Optional. Default value utils.UUID
|
// Optional. Default value utils.UUID
|
||||||
KeyGenerator func() string
|
KeyGenerator func() string
|
||||||
|
|
||||||
// Deprecated, please use KeyLookup
|
// Deprecated: Please use KeyLookup
|
||||||
CookieName string
|
CookieName string
|
||||||
|
|
||||||
// Source defines where to obtain the session id
|
// Source defines where to obtain the session id
|
||||||
|
|
|
@ -2,6 +2,7 @@ package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -49,7 +50,7 @@ type Config struct {
|
||||||
// Optional. Default value utils.UUIDv4
|
// Optional. Default value utils.UUIDv4
|
||||||
KeyGenerator func() string
|
KeyGenerator func() string
|
||||||
|
|
||||||
// Deprecated, please use KeyLookup
|
// Deprecated: Please use KeyLookup
|
||||||
CookieName string
|
CookieName string
|
||||||
|
|
||||||
// Source defines where to obtain the session id
|
// Source defines where to obtain the session id
|
||||||
|
@ -68,8 +69,10 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // Using a global var is fine here
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Expiration: 24 * time.Hour,
|
Expiration: 24 * time.Hour, //nolint:gomnd // No magic number, just the default config
|
||||||
KeyLookup: "cookie:session_id",
|
KeyLookup: "cookie:session_id",
|
||||||
KeyGenerator: utils.UUIDv4,
|
KeyGenerator: utils.UUIDv4,
|
||||||
source: "cookie",
|
source: "cookie",
|
||||||
|
@ -91,7 +94,7 @@ func configDefault(config ...Config) Config {
|
||||||
cfg.Expiration = ConfigDefault.Expiration
|
cfg.Expiration = ConfigDefault.Expiration
|
||||||
}
|
}
|
||||||
if cfg.CookieName != "" {
|
if cfg.CookieName != "" {
|
||||||
fmt.Println("[session] CookieName is deprecated, please use KeyLookup")
|
log.Printf("[session] CookieName is deprecated, please use KeyLookup\n")
|
||||||
cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName)
|
cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName)
|
||||||
}
|
}
|
||||||
if cfg.KeyLookup == "" {
|
if cfg.KeyLookup == "" {
|
||||||
|
@ -102,7 +105,8 @@ func configDefault(config ...Config) Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||||
if len(selectors) != 2 {
|
const numSelectors = 2
|
||||||
|
if len(selectors) != numSelectors {
|
||||||
panic("[session] KeyLookup must in the form of <source>:<name>")
|
panic("[session] KeyLookup must in the form of <source>:<name>")
|
||||||
}
|
}
|
||||||
switch Source(selectors[0]) {
|
switch Source(selectors[0]) {
|
||||||
|
|
|
@ -13,6 +13,7 @@ type data struct {
|
||||||
Data map[string]interface{}
|
Data map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var dataPool = sync.Pool{
|
var dataPool = sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
d := new(data)
|
d := new(data)
|
||||||
|
@ -22,7 +23,7 @@ var dataPool = sync.Pool{
|
||||||
}
|
}
|
||||||
|
|
||||||
func acquireData() *data {
|
func acquireData() *data {
|
||||||
return dataPool.Get().(*data)
|
return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *data) Reset() {
|
func (d *data) Reset() {
|
||||||
|
|
|
@ -3,11 +3,13 @@ package session
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,6 +23,7 @@ type Session struct {
|
||||||
exp time.Duration // expiration of this session
|
exp time.Duration // expiration of this session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var sessionPool = sync.Pool{
|
var sessionPool = sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
return new(Session)
|
return new(Session)
|
||||||
|
@ -28,7 +31,7 @@ var sessionPool = sync.Pool{
|
||||||
}
|
}
|
||||||
|
|
||||||
func acquireSession() *Session {
|
func acquireSession() *Session {
|
||||||
s := sessionPool.Get().(*Session)
|
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||||
if s.data == nil {
|
if s.data == nil {
|
||||||
s.data = acquireData()
|
s.data = acquireData()
|
||||||
}
|
}
|
||||||
|
@ -153,7 +156,7 @@ func (s *Session) Save() error {
|
||||||
encCache := gob.NewEncoder(s.byteBuffer)
|
encCache := gob.NewEncoder(s.byteBuffer)
|
||||||
err := encCache.Encode(&s.data.Data)
|
err := encCache.Encode(&s.data.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to encode data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy the data in buffer
|
// copy the data in buffer
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -94,6 +95,8 @@ func Test_Session(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_Session_Types
|
// go test -run Test_Session_Types
|
||||||
|
//
|
||||||
|
//nolint:forcetypeassert // TODO: Do not force-type assert
|
||||||
func Test_Session_Types(t *testing.T) {
|
func Test_Session_Types(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -127,25 +130,27 @@ func Test_Session_Types(t *testing.T) {
|
||||||
Name: "John",
|
Name: "John",
|
||||||
}
|
}
|
||||||
// set value
|
// set value
|
||||||
var vbool = true
|
var (
|
||||||
var vstring = "str"
|
vbool = true
|
||||||
var vint = 13
|
vstring = "str"
|
||||||
var vint8 int8 = 13
|
vint = 13
|
||||||
var vint16 int16 = 13
|
vint8 int8 = 13
|
||||||
var vint32 int32 = 13
|
vint16 int16 = 13
|
||||||
var vint64 int64 = 13
|
vint32 int32 = 13
|
||||||
var vuint uint = 13
|
vint64 int64 = 13
|
||||||
var vuint8 uint8 = 13
|
vuint uint = 13
|
||||||
var vuint16 uint16 = 13
|
vuint8 uint8 = 13
|
||||||
var vuint32 uint32 = 13
|
vuint16 uint16 = 13
|
||||||
var vuint64 uint64 = 13
|
vuint32 uint32 = 13
|
||||||
var vuintptr uintptr = 13
|
vuint64 uint64 = 13
|
||||||
var vbyte byte = 'k'
|
vuintptr uintptr = 13
|
||||||
var vrune rune = 'k'
|
vbyte byte = 'k'
|
||||||
var vfloat32 float32 = 13
|
vrune = 'k'
|
||||||
var vfloat64 float64 = 13
|
vfloat32 float32 = 13
|
||||||
var vcomplex64 complex64 = 13
|
vfloat64 float64 = 13
|
||||||
var vcomplex128 complex128 = 13
|
vcomplex64 complex64 = 13
|
||||||
|
vcomplex128 complex128 = 13
|
||||||
|
)
|
||||||
sess.Set("vuser", vuser)
|
sess.Set("vuser", vuser)
|
||||||
sess.Set("vbool", vbool)
|
sess.Set("vbool", vbool)
|
||||||
sess.Set("vstring", vstring)
|
sess.Set("vstring", vstring)
|
||||||
|
@ -212,7 +217,8 @@ func Test_Session_Store_Reset(t *testing.T) {
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
|
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
// make sure its new
|
// make sure its new
|
||||||
utils.AssertEqual(t, true, sess.Fresh())
|
utils.AssertEqual(t, true, sess.Fresh())
|
||||||
// set value & save
|
// set value & save
|
||||||
|
@ -224,7 +230,8 @@ func Test_Session_Store_Reset(t *testing.T) {
|
||||||
utils.AssertEqual(t, nil, store.Reset())
|
utils.AssertEqual(t, nil, store.Reset())
|
||||||
|
|
||||||
// make sure the session is recreated
|
// make sure the session is recreated
|
||||||
sess, _ = store.Get(ctx)
|
sess, err = store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, true, sess.Fresh())
|
utils.AssertEqual(t, true, sess.Fresh())
|
||||||
utils.AssertEqual(t, nil, sess.Get("hello"))
|
utils.AssertEqual(t, nil, sess.Get("hello"))
|
||||||
}
|
}
|
||||||
|
@ -242,12 +249,13 @@ func Test_Session_Save(t *testing.T) {
|
||||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
// set value
|
// set value
|
||||||
sess.Set("name", "john")
|
sess.Set("name", "john")
|
||||||
|
|
||||||
// save session
|
// save session
|
||||||
err := sess.Save()
|
err = sess.Save()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -262,12 +270,13 @@ func Test_Session_Save(t *testing.T) {
|
||||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
// set value
|
// set value
|
||||||
sess.Set("name", "john")
|
sess.Set("name", "john")
|
||||||
|
|
||||||
// save session
|
// save session
|
||||||
err := sess.Save()
|
err = sess.Save()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
|
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
|
||||||
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
|
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
|
||||||
|
@ -287,7 +296,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
|
||||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
// set value
|
// set value
|
||||||
sess.Set("name", "john")
|
sess.Set("name", "john")
|
||||||
|
|
||||||
|
@ -295,18 +305,20 @@ func Test_Session_Save_Expiration(t *testing.T) {
|
||||||
sess.SetExpiry(time.Second * 5)
|
sess.SetExpiry(time.Second * 5)
|
||||||
|
|
||||||
// save session
|
// save session
|
||||||
err := sess.Save()
|
err = sess.Save()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
// here you need to get the old session yet
|
// here you need to get the old session yet
|
||||||
sess, _ = store.Get(ctx)
|
sess, err = store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "john", sess.Get("name"))
|
utils.AssertEqual(t, "john", sess.Get("name"))
|
||||||
|
|
||||||
// just to make sure the session has been expired
|
// just to make sure the session has been expired
|
||||||
time.Sleep(time.Second * 5)
|
time.Sleep(time.Second * 5)
|
||||||
|
|
||||||
// here you should get a new session
|
// here you should get a new session
|
||||||
sess, _ = store.Get(ctx)
|
sess, err = store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, nil, sess.Get("name"))
|
utils.AssertEqual(t, nil, sess.Get("name"))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -325,7 +337,8 @@ func Test_Session_Reset(t *testing.T) {
|
||||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
sess.Set("name", "fenny")
|
sess.Set("name", "fenny")
|
||||||
utils.AssertEqual(t, nil, sess.Destroy())
|
utils.AssertEqual(t, nil, sess.Destroy())
|
||||||
|
@ -345,14 +358,16 @@ func Test_Session_Reset(t *testing.T) {
|
||||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
// set value & save
|
// set value & save
|
||||||
sess.Set("name", "fenny")
|
sess.Set("name", "fenny")
|
||||||
utils.AssertEqual(t, nil, sess.Save())
|
utils.AssertEqual(t, nil, sess.Save())
|
||||||
sess, _ = store.Get(ctx)
|
sess, err = store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
err := sess.Destroy()
|
err = sess.Destroy()
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
|
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
|
||||||
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
|
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
|
||||||
|
@ -383,7 +398,8 @@ func Test_Session_Cookie(t *testing.T) {
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
|
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, nil, sess.Save())
|
utils.AssertEqual(t, nil, sess.Save())
|
||||||
|
|
||||||
// cookie should be set on Save ( even if empty data )
|
// cookie should be set on Save ( even if empty data )
|
||||||
|
@ -401,12 +417,14 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
|
||||||
defer app.ReleaseCtx(ctx)
|
defer app.ReleaseCtx(ctx)
|
||||||
|
|
||||||
// get session
|
// get session
|
||||||
sess, _ := store.Get(ctx)
|
sess, err := store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
sess.Set("id", "1")
|
sess.Set("id", "1")
|
||||||
utils.AssertEqual(t, true, sess.Fresh())
|
utils.AssertEqual(t, true, sess.Fresh())
|
||||||
utils.AssertEqual(t, nil, sess.Save())
|
utils.AssertEqual(t, nil, sess.Save())
|
||||||
|
|
||||||
sess, _ = store.Get(ctx)
|
sess, err = store.Get(ctx)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
sess.Set("name", "john")
|
sess.Set("name", "john")
|
||||||
utils.AssertEqual(t, true, sess.Fresh())
|
utils.AssertEqual(t, true, sess.Fresh())
|
||||||
|
|
||||||
|
@ -497,7 +515,7 @@ func Benchmark_Session(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
sess, _ := store.Get(c)
|
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||||
sess.Set("john", "doe")
|
sess.Set("john", "doe")
|
||||||
err = sess.Save()
|
err = sess.Save()
|
||||||
}
|
}
|
||||||
|
@ -512,7 +530,7 @@ func Benchmark_Session(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
sess, _ := store.Get(c)
|
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||||
sess.Set("john", "doe")
|
sess.Set("john", "doe")
|
||||||
err = sess.Save()
|
err = sess.Save()
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,11 +2,13 @@ package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,6 +16,7 @@ type Store struct {
|
||||||
Config
|
Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // TODO: Do not use a global var here
|
||||||
var mux sync.Mutex
|
var mux sync.Mutex
|
||||||
|
|
||||||
func New(config ...Config) *Store {
|
func New(config ...Config) *Store {
|
||||||
|
@ -31,7 +34,7 @@ func New(config ...Config) *Store {
|
||||||
|
|
||||||
// RegisterType will allow you to encode/decode custom types
|
// RegisterType will allow you to encode/decode custom types
|
||||||
// into any Storage provider
|
// into any Storage provider
|
||||||
func (s *Store) RegisterType(i interface{}) {
|
func (*Store) RegisterType(i interface{}) {
|
||||||
gob.Register(i)
|
gob.Register(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,11 +73,11 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
||||||
if raw != nil && err == nil {
|
if raw != nil && err == nil {
|
||||||
mux.Lock()
|
mux.Lock()
|
||||||
defer mux.Unlock()
|
defer mux.Unlock()
|
||||||
_, _ = sess.byteBuffer.Write(raw)
|
_, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
|
||||||
encCache := gob.NewDecoder(sess.byteBuffer)
|
encCache := gob.NewDecoder(sess.byteBuffer)
|
||||||
err := encCache.Decode(&sess.data.Data)
|
err := encCache.Decode(&sess.data.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/utils"
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package skip
|
package skip
|
||||||
|
|
||||||
import "github.com/gofiber/fiber/v2"
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
// New creates a middleware handler which skips the wrapped handler
|
// New creates a middleware handler which skips the wrapped handler
|
||||||
// if the exclude predicate returns true.
|
// if the exclude predicate returns true.
|
||||||
|
|
|
@ -17,7 +17,7 @@ func Test_Skip(t *testing.T) {
|
||||||
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return true }))
|
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return true }))
|
||||||
app.Get("/", helloWorldHandler)
|
app.Get("/", helloWorldHandler)
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ func Test_SkipFalse(t *testing.T) {
|
||||||
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return false }))
|
app.Use(skip.New(errTeapotHandler, func(*fiber.Ctx) bool { return false }))
|
||||||
app.Get("/", helloWorldHandler)
|
app.Get("/", helloWorldHandler)
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func Test_SkipNilFunc(t *testing.T) {
|
||||||
app.Use(skip.New(errTeapotHandler, nil))
|
app.Use(skip.New(errTeapotHandler, nil))
|
||||||
app.Get("/", helloWorldHandler)
|
app.Get("/", helloWorldHandler)
|
||||||
|
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
utils.AssertEqual(t, nil, err)
|
utils.AssertEqual(t, nil, err)
|
||||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,8 @@ func Test_Timeout(t *testing.T) {
|
||||||
// fiber instance
|
// fiber instance
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
h := New(func(c *fiber.Ctx) error {
|
h := New(func(c *fiber.Ctx) error {
|
||||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
|
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
|
||||||
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
|
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
|
||||||
}
|
}
|
||||||
|
@ -26,12 +27,12 @@ func Test_Timeout(t *testing.T) {
|
||||||
}, 100*time.Millisecond)
|
}, 100*time.Millisecond)
|
||||||
app.Get("/test/:sleepTime", h)
|
app.Get("/test/:sleepTime", h)
|
||||||
testTimeout := func(timeoutStr string) {
|
testTimeout := func(timeoutStr string) {
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||||
}
|
}
|
||||||
testSucces := func(timeoutStr string) {
|
testSucces := func(timeoutStr string) {
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||||
}
|
}
|
||||||
|
@ -49,7 +50,8 @@ func Test_TimeoutWithCustomError(t *testing.T) {
|
||||||
// fiber instance
|
// fiber instance
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
h := New(func(c *fiber.Ctx) error {
|
h := New(func(c *fiber.Ctx) error {
|
||||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
|
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
|
||||||
return fmt.Errorf("%w: execution error", err)
|
return fmt.Errorf("%w: execution error", err)
|
||||||
}
|
}
|
||||||
|
@ -57,12 +59,12 @@ func Test_TimeoutWithCustomError(t *testing.T) {
|
||||||
}, 100*time.Millisecond, ErrFooTimeOut)
|
}, 100*time.Millisecond, ErrFooTimeOut)
|
||||||
app.Get("/test/:sleepTime", h)
|
app.Get("/test/:sleepTime", h)
|
||||||
testTimeout := func(timeoutStr string) {
|
testTimeout := func(timeoutStr string) {
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||||
}
|
}
|
||||||
testSucces := func(timeoutStr string) {
|
testSucces := func(timeoutStr string) {
|
||||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||||
}
|
}
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue