Merge branch 'main' into jiejaitt-feature/CSRF-using-Proxy-Middleware

pull/3390/head
JIeJaitt 2025-04-02 17:50:53 +08:00 committed by GitHub
commit 1fdda5dfa6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
88 changed files with 5843 additions and 2176 deletions

4
.github/README.md vendored
View File

@ -124,7 +124,7 @@ We **listen** to our users in [issues](https://github.com/gofiber/fiber/issues),
## ⚠️ Limitations
- Due to Fiber's usage of unsafe, the library may not always be compatible with the latest Go version. Fiber v3 has been tested with Go version 1.23.
- Due to Fiber's usage of unsafe, the library may not always be compatible with the latest Go version. Fiber v3 has been tested with Go version 1.23 or higher.
- Fiber is not compatible with net/http interfaces. This means you will not be able to use projects like gqlgen, go-swagger, or any others which are part of the net/http ecosystem.
## 👀 Examples
@ -708,7 +708,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
| :------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------- |
| [contrib](https://github.com/gofiber/contrib) | Third-party middlewares |
| [storage](https://github.com/gofiber/storage) | Premade storage drivers that implement the Storage interface, designed to be used with various Fiber middlewares. |
| [template](https://github.com/gofiber/template) | This package contains 9 template engines that can be used with Fiber `v3`. Go version 1.23 or higher is required. |
| [template](https://github.com/gofiber/template) | This package contains 9 template engines that can be used with Fiber. |
## 🕶️ Awesome List

9
.github/codecov.yml vendored Normal file
View File

@ -0,0 +1,9 @@
# ignore files or directories to be scanned by codecov
ignore:
- "./docs/"
coverage:
status:
project:
default:
threshold: 0.5%

View File

@ -17,6 +17,7 @@ categories:
- title: '🧹 Updates'
labels:
- '🧹 Updates'
- '⚡️ Performance'
- title: '🐛 Fixes'
labels:
- '☢️ Bug'
@ -48,6 +49,7 @@ version-resolver:
- '☢️ Bug'
- '🤖 Dependencies'
- '🧹 Updates'
- '⚡️ Performance'
default: patch
template: |
$CHANGES

1
.github/release.yml vendored
View File

@ -12,6 +12,7 @@ changelog:
- title: '🧹 Updates'
labels:
- '🧹 Updates'
- '⚡️ Performance'
- title: '🐛 Bug Fixes'
labels:
- '☢️ Bug'

View File

@ -3,11 +3,11 @@ on:
branches:
- master
- main
paths-ignore:
- "**/*.md"
paths:
- "**.go"
pull_request:
paths-ignore:
- "**/*.md"
paths:
- "**.go"
permissions:
# deployments permission to deploy GitHub pages website

View File

@ -37,4 +37,7 @@ jobs:
uses: golangci/golangci-lint-action@v6
with:
# NOTE: Keep this in sync with the version from .golangci.yml
version: v1.62.2
version: v1.64.7
# NOTE(ldez): temporary workaround
install-mode: goinstall

View File

@ -15,7 +15,7 @@ jobs:
unit:
strategy:
matrix:
go-version: [1.23.x]
go-version: [1.23.x, 1.24.x]
platform: [ubuntu-latest, windows-latest, macos-latest, macos-13]
runs-on: ${{ matrix.platform }}
steps:
@ -32,7 +32,7 @@ jobs:
- name: Upload coverage reports to Codecov
if: ${{ matrix.platform == 'ubuntu-latest' && matrix.go-version == '1.23.x' }}
uses: codecov/codecov-action@v5.3.1
uses: codecov/codecov-action@v5.4.0
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage.txt

View File

@ -7,7 +7,6 @@ run:
output:
sort-results: true
uniq-by-line: false
linters-settings:
depguard:
@ -187,7 +186,7 @@ linters-settings:
- name: unchecked-type-assertion
disabled: true # TODO: Do not disable
- name: unhandled-error
arguments: ['bytes\.Buffer\.Write']
disabled: true
stylecheck:
checks:
@ -250,7 +249,10 @@ issues:
max-issues-per-linter: 0
max-same-issues: 0
exclude-dirs:
- internal # TODO: Do not ignore interal packages
- internal # TODO: Do not ignore internal packages
exclude-files:
- '_msgp\.go'
- '_msgp_test\.go'
exclude-rules:
- linters:
- err113
@ -263,7 +265,10 @@ issues:
linters:
- bodyclose
- err113
# fix: true
- source: 'fmt.Fprintf?'
linters:
- errcheck
- revive
linters:
enable:
@ -358,7 +363,6 @@ linters:
- stylecheck
# - tagalign # TODO: Enable
- tagliatelle
- tenv
- testableexamples
- testifylint
# - testpackage # TODO: Enable

View File

@ -35,7 +35,7 @@ markdown:
## lint: 🚨 Run lint checks
.PHONY: lint
lint:
go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.62.2 run ./...
golangci-lint run
## test: 🚦 Execute all tests
.PHONY: test

View File

@ -19,17 +19,47 @@ a jitter is a way to break synchronization across the client and avoid collision
## Signatures
```go
func NewExponentialBackoff(config ...Config) *ExponentialBackoff
func NewExponentialBackoff(config ...retry.Config) *retry.ExponentialBackoff
```
## Examples
Firstly, import the addon from Fiber,
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3/addon/retry"
"github.com/gofiber/fiber/v3/client"
)
func main() {
expBackoff := retry.NewExponentialBackoff(retry.Config{})
// Local variables that will be used inside of Retry
var resp *client.Response
var err error
// Retry a network request and return an error to signify to try again
err = expBackoff.Retry(func() error {
client := client.New()
resp, err = client.Get("https://gofiber.io")
if err != nil {
return fmt.Errorf("GET gofiber.io failed: %w", err)
}
if resp.StatusCode() != 200 {
return fmt.Errorf("GET gofiber.io did not return OK 200")
}
return nil
})
// If all retries failed, panic
if err != nil {
panic(err)
}
fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode())
}
```
## Default Config
@ -58,28 +88,23 @@ type Config struct {
//
// Optional. Default: 1 * time.Second
InitialInterval time.Duration
// MaxBackoffTime defines maximum time duration for backoff algorithm. When
// the algorithm is reached this time, rest of the retries will be maximum
// 32 seconds.
//
// Optional. Default: 32 * time.Second
MaxBackoffTime time.Duration
// Multiplier defines multiplier number of the backoff algorithm.
//
// Optional. Default: 2.0
Multiplier float64
// MaxRetryCount defines maximum retry count for the backoff algorithm.
//
// Optional. Default: 10
MaxRetryCount int
// currentInterval tracks the current waiting time.
//
// Optional. Default: 1 * time.Second
currentInterval time.Duration
}
```

117
app.go
View File

@ -106,10 +106,12 @@ type App struct {
tlsHandler *TLSHandler
// Mount fields
mountFields *mountFields
// state management
state *State
// Route stack divided by HTTP methods
stack [][]*Route
// Route stack divided by HTTP methods and route prefixes
treeStack []map[string][]*Route
treeStack []map[int][]*Route
// custom binders
customBinders []CustomBinder
// customConstraints is a list of external constraints
@ -456,17 +458,29 @@ const (
DefaultWriteBufferSize = 4096
)
const (
methodGet = iota
methodHead
methodPost
methodPut
methodDelete
methodConnect
methodOptions
methodTrace
methodPatch
)
// HTTP methods enabled by default
var DefaultMethods = []string{
MethodGet,
MethodHead,
MethodPost,
MethodPut,
MethodDelete,
MethodConnect,
MethodOptions,
MethodTrace,
MethodPatch,
methodGet: MethodGet,
methodHead: MethodHead,
methodPost: MethodPost,
methodPut: MethodPut,
methodDelete: MethodDelete,
methodConnect: MethodConnect,
methodOptions: MethodOptions,
methodTrace: MethodTrace,
methodPatch: MethodPatch,
}
// DefaultErrorHandler that process return errors from handlers
@ -515,6 +529,9 @@ func New(config ...Config) *App {
// Define mountFields
app.mountFields = newMountFields(app)
// Define state
app.state = newState()
// Override config if provided
if len(config) > 0 {
app.config = config[0]
@ -581,7 +598,7 @@ func New(config ...Config) *App {
// Create router stack
app.stack = make([][]*Route, len(app.config.RequestMethods))
app.treeStack = make([]map[string][]*Route, len(app.config.RequestMethods))
app.treeStack = make([]map[int][]*Route, len(app.config.RequestMethods))
// Override colors
app.config.ColorScheme = defaultColors(app.config.ColorScheme)
@ -750,7 +767,7 @@ func (app *App) Use(args ...any) Router {
return app
}
app.register([]string{methodUse}, prefix, nil, nil, handlers...)
app.register([]string{methodUse}, prefix, nil, handlers...)
}
return app
@ -758,67 +775,67 @@ func (app *App) Use(args ...any) Router {
// Get registers a route for GET methods that requests a representation
// of the specified resource. Requests using GET should only retrieve data.
func (app *App) Get(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodGet}, path, handler, middleware...)
func (app *App) Get(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodGet}, path, handler, handlers...)
}
// Head registers a route for HEAD methods that asks for a response identical
// to that of a GET request, but without the response body.
func (app *App) Head(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodHead}, path, handler, middleware...)
func (app *App) Head(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodHead}, path, handler, handlers...)
}
// Post registers a route for POST methods that is used to submit an entity to the
// specified resource, often causing a change in state or side effects on the server.
func (app *App) Post(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodPost}, path, handler, middleware...)
func (app *App) Post(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodPost}, path, handler, handlers...)
}
// Put registers a route for PUT methods that replaces all current representations
// of the target resource with the request payload.
func (app *App) Put(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodPut}, path, handler, middleware...)
func (app *App) Put(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodPut}, path, handler, handlers...)
}
// Delete registers a route for DELETE methods that deletes the specified resource.
func (app *App) Delete(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodDelete}, path, handler, middleware...)
func (app *App) Delete(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodDelete}, path, handler, handlers...)
}
// Connect registers a route for CONNECT methods that establishes a tunnel to the
// server identified by the target resource.
func (app *App) Connect(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodConnect}, path, handler, middleware...)
func (app *App) Connect(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodConnect}, path, handler, handlers...)
}
// Options registers a route for OPTIONS methods that is used to describe the
// communication options for the target resource.
func (app *App) Options(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodOptions}, path, handler, middleware...)
func (app *App) Options(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodOptions}, path, handler, handlers...)
}
// Trace registers a route for TRACE methods that performs a message loop-back
// test along the path to the target resource.
func (app *App) Trace(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodTrace}, path, handler, middleware...)
func (app *App) Trace(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodTrace}, path, handler, handlers...)
}
// Patch registers a route for PATCH methods that is used to apply partial
// modifications to a resource.
func (app *App) Patch(path string, handler Handler, middleware ...Handler) Router {
return app.Add([]string{MethodPatch}, path, handler, middleware...)
func (app *App) Patch(path string, handler Handler, handlers ...Handler) Router {
return app.Add([]string{MethodPatch}, path, handler, handlers...)
}
// Add allows you to specify multiple HTTP methods to register a route.
func (app *App) Add(methods []string, path string, handler Handler, middleware ...Handler) Router {
app.register(methods, path, nil, handler, middleware...)
func (app *App) Add(methods []string, path string, handler Handler, handlers ...Handler) Router {
app.register(methods, path, nil, append([]Handler{handler}, handlers...)...)
return app
}
// All will register the handler on all HTTP methods
func (app *App) All(path string, handler Handler, middleware ...Handler) Router {
return app.Add(app.config.RequestMethods, path, handler, middleware...)
func (app *App) All(path string, handler Handler, handlers ...Handler) Router {
return app.Add(app.config.RequestMethods, path, handler, handlers...)
}
// Group is used for Routes with common prefix to define a new sub-router with optional middleware.
@ -828,7 +845,7 @@ func (app *App) All(path string, handler Handler, middleware ...Handler) Router
func (app *App) Group(prefix string, handlers ...Handler) Router {
grp := &Group{Prefix: prefix, app: app}
if len(handlers) > 0 {
app.register([]string{methodUse}, prefix, grp, nil, handlers...)
app.register([]string{methodUse}, prefix, grp, handlers...)
}
if err := app.hooks.executeOnGroupHooks(*grp); err != nil {
panic(err)
@ -894,6 +911,13 @@ func (app *App) HandlersCount() uint32 {
//
// Make sure the program doesn't exit and waits instead for Shutdown to return.
//
// Important: app.Listen() must be called in a separate goroutine, otherwise shutdown hooks will not work
// as Listen() is a blocking operation. Example:
//
// go app.Listen(":3000")
// // ...
// app.Shutdown()
//
// Shutdown does not close keepalive connections so its recommended to set ReadTimeout to something else than 0.
func (app *App) Shutdown() error {
return app.ShutdownWithContext(context.Background())
@ -918,17 +942,21 @@ func (app *App) ShutdownWithTimeout(timeout time.Duration) error {
//
// ShutdownWithContext does not close keepalive connections so its recommended to set ReadTimeout to something else than 0.
func (app *App) ShutdownWithContext(ctx context.Context) error {
if app.hooks != nil {
// TODO: check should be defered?
app.hooks.executeOnShutdownHooks()
}
app.mutex.Lock()
defer app.mutex.Unlock()
var err error
if app.server == nil {
return ErrNotRunning
}
return app.server.ShutdownWithContext(ctx)
// Execute the Shutdown hook
app.hooks.executeOnPreShutdownHooks()
defer app.hooks.executeOnPostShutdownHooks(err)
err = app.server.ShutdownWithContext(ctx)
return err
}
// Server returns the underlying fasthttp server
@ -941,6 +969,11 @@ func (app *App) Hooks() *Hooks {
return app.hooks
}
// State returns the state struct to store global data in order to share it between handlers.
func (app *App) State() *State {
return app.state
}
var ErrTestGotEmptyResponse = errors.New("test: got empty response")
// TestConfig is a struct holding Test settings
@ -1013,7 +1046,7 @@ func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, e
select {
case err = <-channel:
case <-time.After(cfg.Timeout):
conn.Close() //nolint:errcheck, revive // It is fine to ignore the error here
conn.Close() //nolint:errcheck // It is fine to ignore the error here
if cfg.FailOnTimeout {
return nil, os.ErrDeadlineExceeded
}

View File

@ -21,6 +21,7 @@ import (
"regexp"
"runtime"
"strings"
"sync"
"testing"
"time"
@ -402,7 +403,7 @@ func Test_App_serverErrorHandler_Internal_Error(t *testing.T) {
t.Parallel()
app := New()
msg := "test err"
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
app.serverErrorHandler(c.fasthttp, errors.New(msg))
require.Equal(t, string(c.fasthttp.Response.Body()), msg)
@ -412,7 +413,7 @@ func Test_App_serverErrorHandler_Internal_Error(t *testing.T) {
func Test_App_serverErrorHandler_Network_Error(t *testing.T) {
t.Parallel()
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
app.serverErrorHandler(c.fasthttp, &net.DNSError{
Err: "test error",
@ -479,14 +480,10 @@ func Test_App_Use_Params(t *testing.T) {
require.NoError(t, err, "app.Test(req)")
require.Equal(t, 200, resp.StatusCode, "Status code")
defer func() {
if err := recover(); err != nil {
require.Equal(t, "use: invalid handler func()\n", fmt.Sprintf("%v", err))
}
}()
app.Use("/:param/*", func() {
// this should panic
require.PanicsWithValue(t, "use: invalid handler func()\n", func() {
app.Use("/:param/*", func() {
// this should panic
})
})
}
@ -930,20 +927,29 @@ func Test_App_ShutdownWithTimeout(t *testing.T) {
})
ln := fasthttputil.NewInmemoryListener()
serverReady := make(chan struct{}) // Signal that the server is ready to start
go func() {
serverReady <- struct{}{}
err := app.Listener(ln)
assert.NoError(t, err)
}()
time.Sleep(1 * time.Second)
<-serverReady // Waiting for the server to be ready
// Create a connection and send a request
connReady := make(chan struct{})
go func() {
conn, err := ln.Dial()
assert.NoError(t, err)
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n"))
assert.NoError(t, err)
connReady <- struct{}{} // Signal that the request has been sent
}()
time.Sleep(1 * time.Second)
<-connReady // Waiting for the request to be sent
shutdownErr := make(chan error)
go func() {
@ -964,46 +970,130 @@ func Test_App_ShutdownWithTimeout(t *testing.T) {
func Test_App_ShutdownWithContext(t *testing.T) {
t.Parallel()
app := New()
app.Get("/", func(ctx Ctx) error {
time.Sleep(5 * time.Second)
return ctx.SendString("body")
t.Run("successful shutdown", func(t *testing.T) {
t.Parallel()
app := New()
// Fast request that should complete
app.Get("/", func(c Ctx) error {
return c.SendString("OK")
})
ln := fasthttputil.NewInmemoryListener()
serverStarted := make(chan bool, 1)
go func() {
serverStarted <- true
if err := app.Listener(ln); err != nil {
t.Errorf("Failed to start listener: %v", err)
}
}()
<-serverStarted
// Execute normal request
conn, err := ln.Dial()
require.NoError(t, err)
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
require.NoError(t, err)
// Shutdown with sufficient timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = app.ShutdownWithContext(ctx)
require.NoError(t, err, "Expected successful shutdown")
})
ln := fasthttputil.NewInmemoryListener()
t.Run("shutdown with hooks", func(t *testing.T) {
t.Parallel()
app := New()
go func() {
err := app.Listener(ln)
assert.NoError(t, err)
}()
hookOrder := make([]string, 0)
var hookMutex sync.Mutex
time.Sleep(1 * time.Second)
app.Hooks().OnPreShutdown(func() error {
hookMutex.Lock()
hookOrder = append(hookOrder, "pre")
hookMutex.Unlock()
return nil
})
go func() {
conn, err := ln.Dial()
assert.NoError(t, err)
app.Hooks().OnPostShutdown(func(_ error) error {
hookMutex.Lock()
hookOrder = append(hookOrder, "post")
hookMutex.Unlock()
return nil
})
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n"))
assert.NoError(t, err)
}()
ln := fasthttputil.NewInmemoryListener()
go func() {
if err := app.Listener(ln); err != nil {
t.Errorf("Failed to start listener: %v", err)
}
}()
time.Sleep(1 * time.Second)
time.Sleep(100 * time.Millisecond)
shutdownErr := make(chan error)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
shutdownErr <- app.ShutdownWithContext(ctx)
}()
err := app.ShutdownWithContext(context.Background())
require.NoError(t, err)
select {
case <-time.After(5 * time.Second):
t.Fatal("idle connections not closed on shutdown")
case err := <-shutdownErr:
if err == nil || !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
require.Equal(t, []string{"pre", "post"}, hookOrder, "Hooks should execute in order")
})
t.Run("timeout with long running request", func(t *testing.T) {
t.Parallel()
app := New()
requestStarted := make(chan struct{})
requestProcessing := make(chan struct{})
app.Get("/", func(c Ctx) error {
close(requestStarted)
// Wait for signal to continue processing the request
<-requestProcessing
time.Sleep(2 * time.Second)
return c.SendString("OK")
})
ln := fasthttputil.NewInmemoryListener()
go func() {
if err := app.Listener(ln); err != nil {
t.Errorf("Failed to start listener: %v", err)
}
}()
// Ensure server is fully started
time.Sleep(100 * time.Millisecond)
// Start a long-running request
go func() {
conn, err := ln.Dial()
if err != nil {
t.Errorf("Failed to dial: %v", err)
return
}
if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")); err != nil {
t.Errorf("Failed to write: %v", err)
}
}()
// Wait for request to start
select {
case <-requestStarted:
// Request has started, signal to continue processing
close(requestProcessing)
case <-time.After(2 * time.Second):
t.Fatal("Request did not start in time")
}
}
// Attempt shutdown, should timeout
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
err := app.ShutdownWithContext(ctx)
require.ErrorIs(t, err, context.DeadlineExceeded)
})
}
// go test -run Test_App_Mixed_Routes_WithSameLen
@ -1055,12 +1145,10 @@ func Test_App_Mixed_Routes_WithSameLen(t *testing.T) {
func Test_App_Group_Invalid(t *testing.T) {
t.Parallel()
defer func() {
if err := recover(); err != nil {
require.Equal(t, "use: invalid handler int\n", fmt.Sprintf("%v", err))
}
}()
New().Group("/").Use(1)
require.PanicsWithValue(t, "use: invalid handler int\n", func() {
New().Group("/").Use(1)
})
}
func Test_App_Group(t *testing.T) {
@ -1285,14 +1373,10 @@ func Test_App_Init_Error_View(t *testing.T) {
t.Parallel()
app := New(Config{Views: invalidView{}})
defer func() {
if err := recover(); err != nil {
require.Equal(t, "implement me", fmt.Sprintf("%v", err))
}
}()
err := app.config.Views.Render(nil, "", nil)
require.NoError(t, err)
require.PanicsWithValue(t, "implement me", func() {
//nolint:errcheck // not needed
_ = app.config.Views.Render(nil, "", nil)
})
}
// go test -run Test_App_Stack
@ -1806,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) {
require.Equal(t, "/simple-route", sRoute2.Path)
}
func Test_App_State(t *testing.T) {
t.Parallel()
app := New()
app.State().Set("key", "value")
str, ok := app.State().GetString("key")
require.True(t, ok)
require.Equal(t, "value", str)
}
// go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4
func Benchmark_Communication_Flow(b *testing.B) {
app := New()

View File

@ -1378,7 +1378,7 @@ func Benchmark_Bind_URI(b *testing.B) {
var err error
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.route = &Route{
Params: []string{
@ -1415,7 +1415,7 @@ func Benchmark_Bind_URI_Map(b *testing.B) {
var err error
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.route = &Route{
Params: []string{

View File

@ -28,7 +28,7 @@ Fiber provides several default binders out of the box:
### Binding into a Struct
Fiber supports binding request data directly into a struct using [gorilla/schema](https://github.com/gorilla/schema). Here's an example:
Fiber supports binding request data directly into a struct using [gofiber/schema](https://github.com/gofiber/schema). Here's an example:
```go
// Field names must start with an uppercase letter

View File

@ -1,6 +1,8 @@
package binder
import (
"mime/multipart"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
@ -59,7 +61,15 @@ func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
}
}
return parse(b.Name(), out, data)
files := make(map[string][]*multipart.FileHeader)
for key, values := range multipartForm.File {
err = formatBindData(out, files, key, values, b.EnableSplitting, true)
if err != nil {
return err
}
}
return parse(b.Name(), out, data, files)
}
// Reset resets the FormBinding binder.

View File

@ -2,6 +2,7 @@ package binder
import (
"bytes"
"io"
"mime/multipart"
"testing"
@ -57,19 +58,19 @@ func Benchmark_FormBinder_Bind(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
binder := &QueryBinding{
binder := &FormBinding{
EnableSplitting: true,
}
type User struct {
Name string `query:"name"`
Posts []string `query:"posts"`
Age int `query:"age"`
Name string `form:"name"`
Posts []string `form:"posts"`
Age int `form:"age"`
}
var user User
req := fasthttp.AcquireRequest()
req.URI().SetQueryString("name=john&age=42&posts=post1,post2,post3")
req.SetBodyString("name=john&age=42&posts=post1,post2,post3")
req.Header.SetContentType("application/x-www-form-urlencoded")
b.ResetTimer()
@ -98,10 +99,12 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
}
type User struct {
Name string `form:"name"`
Names []string `form:"names"`
Posts []Post `form:"posts"`
Age int `form:"age"`
Avatar *multipart.FileHeader `form:"avatar"`
Name string `form:"name"`
Names []string `form:"names"`
Posts []Post `form:"posts"`
Avatars []*multipart.FileHeader `form:"avatars"`
Age int `form:"age"`
}
var user User
@ -118,6 +121,24 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))
writer, err := mw.CreateFormFile("avatar", "avatar.txt")
require.NoError(t, err)
_, err = writer.Write([]byte("avatar"))
require.NoError(t, err)
writer, err = mw.CreateFormFile("avatars", "avatar1.txt")
require.NoError(t, err)
_, err = writer.Write([]byte("avatar1"))
require.NoError(t, err)
writer, err = mw.CreateFormFile("avatars", "avatar2.txt")
require.NoError(t, err)
_, err = writer.Write([]byte("avatar2"))
require.NoError(t, err)
require.NoError(t, mw.Close())
req.Header.SetContentType(mw.FormDataContentType())
@ -127,7 +148,7 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
fasthttp.ReleaseRequest(req)
})
err := b.Bind(req, &user)
err = b.Bind(req, &user)
require.NoError(t, err)
require.Equal(t, "john", user.Name)
@ -139,6 +160,38 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
require.Equal(t, "post1", user.Posts[0].Title)
require.Equal(t, "post2", user.Posts[1].Title)
require.Equal(t, "post3", user.Posts[2].Title)
require.NotNil(t, user.Avatar)
require.Equal(t, "avatar.txt", user.Avatar.Filename)
require.Equal(t, "application/octet-stream", user.Avatar.Header.Get("Content-Type"))
file, err := user.Avatar.Open()
require.NoError(t, err)
content, err := io.ReadAll(file)
require.NoError(t, err)
require.Equal(t, "avatar", string(content))
require.Len(t, user.Avatars, 2)
require.Equal(t, "avatar1.txt", user.Avatars[0].Filename)
require.Equal(t, "application/octet-stream", user.Avatars[0].Header.Get("Content-Type"))
file, err = user.Avatars[0].Open()
require.NoError(t, err)
content, err = io.ReadAll(file)
require.NoError(t, err)
require.Equal(t, "avatar1", string(content))
require.Equal(t, "avatar2.txt", user.Avatars[1].Filename)
require.Equal(t, "application/octet-stream", user.Avatars[1].Header.Get("Content-Type"))
file, err = user.Avatars[1].Open()
require.NoError(t, err)
content, err = io.ReadAll(file)
require.NoError(t, err)
require.Equal(t, "avatar2", string(content))
}
func Benchmark_FormBinder_BindMultipart(b *testing.B) {

View File

@ -3,6 +3,7 @@ package binder
import (
"errors"
"fmt"
"mime/multipart"
"reflect"
"strings"
"sync"
@ -69,7 +70,7 @@ func init() {
}
// parse data into the map or struct
func parse(aliasTag string, out any, data map[string][]string) error {
func parse(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
ptrVal := reflect.ValueOf(out)
// Get pointer value
@ -83,11 +84,11 @@ func parse(aliasTag string, out any, data map[string][]string) error {
}
// Parse into the struct
return parseToStruct(aliasTag, out, data)
return parseToStruct(aliasTag, out, data, files...)
}
// Parse data into the struct with gorilla/schema
func parseToStruct(aliasTag string, out any, data map[string][]string) error {
// Parse data into the struct with gofiber/schema
func parseToStruct(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
// Get decoder from pool
schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder) //nolint:errcheck,forcetypeassert // not needed
defer decoderPoolMap[aliasTag].Put(schemaDecoder)
@ -95,7 +96,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error {
// Set alias tag
schemaDecoder.SetAliasTag(aliasTag)
if err := schemaDecoder.Decode(out, data); err != nil {
if err := schemaDecoder.Decode(out, data, files...); err != nil {
return fmt.Errorf("bind: %w", err)
}
@ -250,7 +251,7 @@ func FilterFlags(content string) string {
return content
}
func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
func formatBindData[T, K any](out any, data map[string][]T, key string, value K, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
var err error
if supportBracketNotation && strings.Contains(key, "[") {
key, err = parseParamSquareBrackets(key)
@ -261,10 +262,28 @@ func formatBindData[T any](out any, data map[string][]string, key string, value
switch v := any(value).(type) {
case string:
assignBindData(out, data, key, v, enableSplitting)
dataMap, ok := any(data).(map[string][]string)
if !ok {
return fmt.Errorf("unsupported value type: %T", value)
}
assignBindData(out, dataMap, key, v, enableSplitting)
case []string:
dataMap, ok := any(data).(map[string][]string)
if !ok {
return fmt.Errorf("unsupported value type: %T", value)
}
for _, val := range v {
assignBindData(out, data, key, val, enableSplitting)
assignBindData(out, dataMap, key, val, enableSplitting)
}
case []*multipart.FileHeader:
for _, val := range v {
valT, ok := any(val).(T)
if !ok {
return fmt.Errorf("unsupported value type: %T", value)
}
data[key] = append(data[key], valT)
}
default:
return fmt.Errorf("unsupported value type: %T", value)

View File

@ -2,6 +2,7 @@ package binder
import (
"errors"
"mime/multipart"
"reflect"
"testing"
@ -9,6 +10,8 @@ import (
)
func Test_EqualFieldType(t *testing.T) {
t.Parallel()
var out int
require.False(t, equalFieldType(&out, reflect.Int, "key"))
@ -47,6 +50,8 @@ func Test_EqualFieldType(t *testing.T) {
}
func Test_ParseParamSquareBrackets(t *testing.T) {
t.Parallel()
tests := []struct {
err error
input string
@ -101,6 +106,8 @@ func Test_ParseParamSquareBrackets(t *testing.T) {
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
result, err := parseParamSquareBrackets(tt.input)
if tt.err != nil {
require.Error(t, err)
@ -114,6 +121,8 @@ func Test_ParseParamSquareBrackets(t *testing.T) {
}
func Test_parseToMap(t *testing.T) {
t.Parallel()
inputMap := map[string][]string{
"key1": {"value1", "value2"},
"key2": {"value3"},
@ -147,6 +156,8 @@ func Test_parseToMap(t *testing.T) {
}
func Test_FilterFlags(t *testing.T) {
t.Parallel()
tests := []struct {
input string
expected string
@ -172,8 +183,163 @@ func Test_FilterFlags(t *testing.T) {
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
result := FilterFlags(tt.input)
require.Equal(t, tt.expected, result)
})
}
}
func TestFormatBindData(t *testing.T) {
t.Parallel()
t.Run("string value with valid key", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]string)
err := formatBindData(out, data, "name", "John", false, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(data["name"]) != 1 || data["name"][0] != "John" {
t.Fatalf("expected data[\"name\"] = [John], got %v", data["name"])
}
})
t.Run("unsupported value type", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]string)
err := formatBindData(out, data, "age", 30, false, false) // int is unsupported
if err == nil {
t.Fatal("expected an error, got nil")
}
})
t.Run("bracket notation parsing error", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]string)
err := formatBindData(out, data, "invalid[", "value", false, true) // malformed bracket notation
if err == nil {
t.Fatal("expected an error, got nil")
}
})
t.Run("handling multipart file headers", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]*multipart.FileHeader)
files := []*multipart.FileHeader{
{Filename: "file1.txt"},
{Filename: "file2.txt"},
}
err := formatBindData(out, data, "files", files, false, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(data["files"]) != 2 {
t.Fatalf("expected 2 files, got %d", len(data["files"]))
}
})
t.Run("type casting error", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := map[string][]int{} // Incorrect type to force a casting error
err := formatBindData(out, data, "key", "value", false, false)
require.Equal(t, "unsupported value type: string", err.Error())
})
}
func TestAssignBindData(t *testing.T) {
t.Parallel()
t.Run("splitting enabled with comma", func(t *testing.T) {
t.Parallel()
out := struct {
Colors []string `query:"colors"`
}{}
data := make(map[string][]string)
assignBindData(&out, data, "colors", "red,blue,green", true)
require.Len(t, data["colors"], 3)
})
t.Run("splitting disabled", func(t *testing.T) {
t.Parallel()
var out []string
data := make(map[string][]string)
assignBindData(out, data, "color", "red,blue", false)
require.Len(t, data["color"], 1)
})
}
func Test_parseToStruct_MismatchedData(t *testing.T) {
t.Parallel()
type User struct {
Name string `query:"name"`
Age int `query:"age"`
}
data := map[string][]string{
"name": {"John"},
"age": {"invalidAge"},
}
err := parseToStruct("query", &User{}, data)
require.Error(t, err)
require.EqualError(t, err, "bind: schema: error converting value for \"age\"")
}
func Test_formatBindData_ErrorCases(t *testing.T) {
t.Parallel()
t.Run("unsupported value type int", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]string)
err := formatBindData(out, data, "age", 30, false, false) // int is unsupported
require.Error(t, err)
require.EqualError(t, err, "unsupported value type: int")
})
t.Run("unsupported value type map", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]string)
err := formatBindData(out, data, "map", map[string]string{"key": "value"}, false, false) // map is unsupported
require.Error(t, err)
require.EqualError(t, err, "unsupported value type: map[string]string")
})
t.Run("bracket notation parsing error", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]string)
err := formatBindData(out, data, "invalid[", "value", false, true) // malformed bracket notation
require.Error(t, err)
require.EqualError(t, err, "unmatched brackets")
})
t.Run("type casting error for []string", func(t *testing.T) {
t.Parallel()
out := struct{}{}
data := make(map[string][]string)
err := formatBindData(out, data, "names", 123, false, false) // invalid type for []string
require.Error(t, err)
require.EqualError(t, err, "unsupported value type: int")
})
}

View File

@ -131,7 +131,7 @@ func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *
}
}
func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint: unparam // maybe needed
func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint:unparam // maybe needed
t.Helper()
app, ln, start := createHelperServer(t)

View File

@ -3,24 +3,22 @@ package client
import (
"fmt"
"io"
"math/rand"
"math/rand/v2"
"mime/multipart"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
var (
protocolCheck = regexp.MustCompile(`^https?://.*$`)
headerAccept = "Accept"
var protocolCheck = regexp.MustCompile(`^https?://.*$`)
const (
headerAccept = "Accept"
applicationJSON = "application/json"
applicationCBOR = "application/cbor"
applicationXML = "application/xml"
@ -30,25 +28,26 @@ var (
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting into 63 bits
letterIdxMax = 64 / letterIdxBits // # of letter indices fitting into 64 bits
)
// randString returns a random string of length n.
func randString(n int) string {
// unsafeRandString returns a random string of length n.
func unsafeRandString(n int) string {
b := make([]byte, n)
length := len(letterBytes)
src := rand.NewSource(time.Now().UnixNano())
const length = uint64(len(letterBytes))
for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
//nolint:gosec // Not a concern
for i, cache, remain := n-1, rand.Uint64(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
//nolint:gosec // Not a concern
cache, remain = rand.Uint64(), letterIdxMax
}
if idx := int(cache & int64(letterIdxMask)); idx < length {
if idx := cache & letterIdxMask; idx < length {
b[i] = letterBytes[idx]
i--
}
cache >>= int64(letterIdxBits)
cache >>= letterIdxBits
remain--
}
@ -134,7 +133,7 @@ func parserRequestHeader(c *Client, req *Request) error {
req.RawRequest.Header.SetContentType(multipartFormData)
// If boundary is default, append a random string to it.
if req.boundary == boundary {
req.boundary += randString(16)
req.boundary += unsafeRandString(16)
}
req.RawRequest.Header.SetMultipartFormBoundary(req.boundary)
default:
@ -200,7 +199,7 @@ func parserRequestBody(c *Client, req *Request) error {
case filesBody:
return parserRequestBodyFile(req)
case rawBody:
if body, ok := req.body.([]byte); ok {
if body, ok := req.body.([]byte); ok { //nolint:revive // ignore simplicity
req.RawRequest.SetBody(body)
} else {
return ErrBodyType

View File

@ -38,7 +38,7 @@ func Test_Rand_String(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := randString(tt.args)
got := unsafeRandString(tt.args)
require.Len(t, got, tt.args)
})
}

View File

@ -298,11 +298,12 @@ func (r *Request) Cookie(key string) string {
// Use maps.Collect() to gather them into a map if needed.
func (r *Request) Cookies() iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
r.cookies.VisitAll(func(key, val string) {
if !yield(key, val) {
for k, v := range *r.cookies {
res := yield(k, v)
if !res {
return
}
})
}
}
}
@ -343,11 +344,11 @@ func (r *Request) PathParam(key string) string {
// Use maps.Collect() to gather them into a map if needed.
func (r *Request) PathParams() iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
r.path.VisitAll(func(key, val string) {
if !yield(key, val) {
for k, v := range *r.path {
if !yield(k, v) {
return
}
})
}
}
}

View File

@ -1,3 +1,4 @@
//nolint:goconst // Much easier to just ignore memory leaks in tests
package client
import (
@ -451,6 +452,14 @@ func Test_Request_Cookies(t *testing.T) {
require.Equal(t, "bar", cookies["foo"])
require.Equal(t, "foo", cookies["bar"])
require.NotPanics(t, func() {
for _, v := range req.Cookies() {
if v == "bar" {
break
}
}
})
require.Len(t, cookies, 2)
}
@ -564,6 +573,14 @@ func Test_Request_PathParams(t *testing.T) {
require.Equal(t, "foo", pathParams["bar"])
require.Len(t, pathParams, 2)
require.NotPanics(t, func() {
for _, v := range req.PathParams() {
if v == "bar" {
break
}
}
})
}
func Benchmark_Request_PathParams(b *testing.B) {
@ -1579,7 +1596,7 @@ func Test_SetValWithStruct(t *testing.T) {
require.True(t, func() bool {
for _, v := range p.PeekMulti("TSlice") {
if string(v) == "bar" { //nolint:goconst // test
if string(v) == "bar" {
return true
}
}

116
ctx.go
View File

@ -33,8 +33,11 @@ const (
schemeHTTPS = "https"
)
// maxParams defines the maximum number of parameters per route.
const maxParams = 30
const (
// maxParams defines the maximum number of parameters per route.
maxParams = 30
maxDetectionPaths = 3
)
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
@ -49,26 +52,25 @@ const userContextKey contextKey = 0 // __local_user_context__
//
//go:generate ifacemaker --file ctx.go --struct DefaultCtx --iface Ctx --pkg fiber --output ctx_interface_gen.go --not-exported true --iface-comment "Ctx represents the Context which hold the HTTP request and response.\nIt has methods for the request query string, parameters, body, HTTP headers and so on."
type DefaultCtx struct {
app *App // Reference to *App
route *Route // Reference to *Route
fasthttp *fasthttp.RequestCtx // Reference to *fasthttp.RequestCtx
bind *Bind // Default bind reference
redirect *Redirect // Default redirect reference
values [maxParams]string // Route parameter values
viewBindMap sync.Map // Default view map to bind template engine
method string // HTTP method
baseURI string // HTTP base uri
path string // HTTP path with the modifications by the configuration -> string copy from pathBuffer
detectionPath string // Route detection path -> string copy from detectionPathBuffer
treePath string // Path for the search in the tree
pathOriginal string // Original HTTP path
pathBuffer []byte // HTTP path buffer
detectionPathBuffer []byte // HTTP detectionPath buffer
flashMessages redirectionMsgs // Flash messages
indexRoute int // Index of the current route
indexHandler int // Index of the current handler
methodINT int // HTTP method INT equivalent
matched bool // Non use route matched
app *App // Reference to *App
route *Route // Reference to *Route
fasthttp *fasthttp.RequestCtx // Reference to *fasthttp.RequestCtx
bind *Bind // Default bind reference
redirect *Redirect // Default redirect reference
req *DefaultReq // Default request api reference
res *DefaultRes // Default response api reference
values [maxParams]string // Route parameter values
viewBindMap sync.Map // Default view map to bind template engine
baseURI string // HTTP base uri
pathOriginal string // Original HTTP path
flashMessages redirectionMsgs // Flash messages
path []byte // HTTP path with the modifications by the configuration
detectionPath []byte // Route detection path
treePathHash int // Hash of the path for the search in the tree
indexRoute int // Index of the current route
indexHandler int // Index of the current handler
methodInt int // HTTP method INT equivalent
matched bool // Non use route matched
}
// SendFile defines configuration options when to transfer file with SendFile.
@ -1003,19 +1005,17 @@ func (c *DefaultCtx) Location(path string) {
func (c *DefaultCtx) Method(override ...string) string {
if len(override) == 0 {
// Nothing to override, just return current method from context
return c.method
return c.app.method(c.methodInt)
}
method := utils.ToUpper(override[0])
mINT := c.app.methodInt(method)
if mINT == -1 {
methodInt := c.app.methodInt(method)
if methodInt == -1 {
// Provided override does not valid HTTP method, no override, return current method
return c.method
return c.app.method(c.methodInt)
}
c.method = method
c.methodINT = mINT
return c.method
c.methodInt = methodInt
return method
}
// MultipartForm parse form entries from binary.
@ -1121,8 +1121,9 @@ func Params[V GenericType](c Ctx, key string, defaultValue ...V) V {
// Path returns the path part of the request URL.
// Optionally, you could override the path.
// Make copies or use the Immutable setting to use the value outside the Handler.
func (c *DefaultCtx) Path(override ...string) string {
if len(override) != 0 && c.path != override[0] {
if len(override) != 0 && string(c.path) != override[0] {
// Set new path to context
c.pathOriginal = override[0]
@ -1131,7 +1132,7 @@ func (c *DefaultCtx) Path(override ...string) string {
// Prettify path
c.configDependentPaths()
}
return c.path
return c.app.getString(c.path)
}
// Scheme contains the request protocol string: http or https for TLS requests.
@ -1347,7 +1348,7 @@ func (c *DefaultCtx) getLocationFromRoute(route Route, params Map) (string, erro
for key, val := range params {
isSame := key == segment.ParamName || (!c.app.config.CaseSensitive && utils.EqualFold(key, segment.ParamName))
isGreedy := segment.IsGreedy && len(key) == 1 && isInCharset(key[0], greedyParameters)
isGreedy := segment.IsGreedy && len(key) == 1 && bytes.IndexByte(greedyParameters, key[0]) != -1
if isSame || isGreedy {
_, err := buf.WriteString(utils.ToString(val))
if err != nil {
@ -1463,6 +1464,18 @@ func (c *DefaultCtx) renderExtensions(bind any) {
}
}
// Req returns a convenience type whose API is limited to operations
// on the incoming request.
func (c *DefaultCtx) Req() Req {
return c.req
}
// Res returns a convenience type whose API is limited to operations
// on the outgoing response.
func (c *DefaultCtx) Res() Res {
return c.res
}
// Route returns the matched Route struct.
func (c *DefaultCtx) Route() *Route {
if c.route == nil {
@ -1470,7 +1483,7 @@ func (c *DefaultCtx) Route() *Route {
return &Route{
path: c.pathOriginal,
Path: c.pathOriginal,
Method: c.method,
Method: c.Method(),
Handlers: make([]Handler, 0),
Params: make([]string, 0),
}
@ -1555,6 +1568,7 @@ func (c *DefaultCtx) SendFile(file string, config ...SendFile) error {
AcceptByteRange: cfg.ByteRange,
Compress: cfg.Compress,
CompressBrotli: cfg.Compress,
CompressZstd: cfg.Compress,
CompressedFileSuffixes: c.app.config.CompressedFileSuffixes,
CacheDuration: cfg.CacheDuration,
SkipCache: cfg.CacheDuration < 0,
@ -1817,32 +1831,31 @@ func (c *DefaultCtx) XHR() bool {
// configDependentPaths set paths for route recognition and prepared paths for the user,
// here the features for caseSensitive, decoded paths, strict paths are evaluated
func (c *DefaultCtx) configDependentPaths() {
c.pathBuffer = append(c.pathBuffer[0:0], c.pathOriginal...)
c.path = append(c.path[:0], c.pathOriginal...)
// If UnescapePath enabled, we decode the path and save it for the framework user
if c.app.config.UnescapePath {
c.pathBuffer = fasthttp.AppendUnquotedArg(c.pathBuffer[:0], c.pathBuffer)
c.path = fasthttp.AppendUnquotedArg(c.path[:0], c.path)
}
c.path = c.app.getString(c.pathBuffer)
// another path is specified which is for routing recognition only
// use the path that was changed by the previous configuration flags
c.detectionPathBuffer = append(c.detectionPathBuffer[0:0], c.pathBuffer...)
c.detectionPath = append(c.detectionPath[:0], c.path...)
// If CaseSensitive is disabled, we lowercase the original path
if !c.app.config.CaseSensitive {
c.detectionPathBuffer = utils.ToLowerBytes(c.detectionPathBuffer)
c.detectionPath = utils.ToLowerBytes(c.detectionPath)
}
// If StrictRouting is disabled, we strip all trailing slashes
if !c.app.config.StrictRouting && len(c.detectionPathBuffer) > 1 && c.detectionPathBuffer[len(c.detectionPathBuffer)-1] == '/' {
c.detectionPathBuffer = utils.TrimRight(c.detectionPathBuffer, '/')
if !c.app.config.StrictRouting && len(c.detectionPath) > 1 && c.detectionPath[len(c.detectionPath)-1] == '/' {
c.detectionPath = utils.TrimRight(c.detectionPath, '/')
}
c.detectionPath = c.app.getString(c.detectionPathBuffer)
// Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed,
// since the first three characters area select a list of routes
c.treePath = c.treePath[0:0]
const maxDetectionPaths = 3
c.treePathHash = 0
if len(c.detectionPath) >= maxDetectionPaths {
c.treePath = c.detectionPath[:maxDetectionPaths]
c.treePathHash = int(c.detectionPath[0])<<16 |
int(c.detectionPath[1])<<8 |
int(c.detectionPath[2])
}
}
@ -1903,8 +1916,7 @@ func (c *DefaultCtx) Reset(fctx *fasthttp.RequestCtx) {
// Set paths
c.pathOriginal = c.app.getString(fctx.URI().PathOriginal())
// Set method
c.method = c.app.getString(fctx.Request.Header.Method())
c.methodINT = c.app.methodInt(c.method)
c.methodInt = c.app.methodInt(utils.UnsafeString(fctx.Request.Header.Method()))
// Attach *fasthttp.RequestCtx to ctx
c.fasthttp = fctx
// reset base uri
@ -1935,20 +1947,20 @@ func (c *DefaultCtx) getBody() []byte {
}
// Methods to use with next stack.
func (c *DefaultCtx) getMethodINT() int {
return c.methodINT
func (c *DefaultCtx) getMethodInt() int {
return c.methodInt
}
func (c *DefaultCtx) getIndexRoute() int {
return c.indexRoute
}
func (c *DefaultCtx) getTreePath() string {
return c.treePath
func (c *DefaultCtx) getTreePathHash() int {
return c.treePathHash
}
func (c *DefaultCtx) getDetectionPath() string {
return c.detectionPath
return c.app.getString(c.detectionPath)
}
func (c *DefaultCtx) getPathOriginal() string {

View File

@ -17,9 +17,9 @@ type CustomCtx interface {
Reset(fctx *fasthttp.RequestCtx)
// Methods to use with next stack.
getMethodINT() int
getMethodInt() int
getIndexRoute() int
getTreePath() string
getTreePathHash() int
getDetectionPath() string
getPathOriginal() string
getValues() *[maxParams]string
@ -32,10 +32,14 @@ type CustomCtx interface {
func NewDefaultCtx(app *App) *DefaultCtx {
// return ctx
return &DefaultCtx{
ctx := &DefaultCtx{
// Set app reference
app: app,
}
ctx.req = &DefaultReq{ctx: ctx}
ctx.res = &DefaultRes{ctx: ctx}
return ctx
}
func (app *App) newCtx() Ctx {

View File

@ -265,6 +265,12 @@ type Ctx interface {
// We support the following engines: https://github.com/gofiber/template
Render(name string, bind any, layouts ...string) error
renderExtensions(bind any)
// Req returns a convenience type whose API is limited to operations
// on the incoming request.
Req() Req
// Res returns a convenience type whose API is limited to operations
// on the outgoing response.
Res() Res
// Route returns the matched Route struct.
Route() *Route
// SaveFile saves any multipart file to disk.
@ -339,9 +345,9 @@ type Ctx interface {
release()
getBody() []byte
// Methods to use with next stack.
getMethodINT() int
getMethodInt() int
getIndexRoute() int
getTreePath() string
getTreePathHash() int
getDetectionPath() string
getPathOriginal() string
getValues() *[maxParams]string

View File

@ -46,7 +46,7 @@ func Test_Ctx_Accepts(t *testing.T) {
c.Request().Header.Set(HeaderAccept, "text/html,application/xhtml+xml,application/xml;q=0.9")
require.Equal(t, "", c.Accepts(""))
require.Equal(t, "", c.Accepts())
require.Equal(t, "", c.Req().Accepts())
require.Equal(t, ".xml", c.Accepts(".xml"))
require.Equal(t, "", c.Accepts(".john"))
require.Equal(t, "application/xhtml+xml", c.Accepts("application/xml", "application/xml+rss", "application/yaml", "application/xhtml+xml"), "must use client-preferred mime type")
@ -57,13 +57,13 @@ func Test_Ctx_Accepts(t *testing.T) {
c.Request().Header.Set(HeaderAccept, "text/*, application/json")
require.Equal(t, "html", c.Accepts("html"))
require.Equal(t, "text/html", c.Accepts("text/html"))
require.Equal(t, "json", c.Accepts("json", "text"))
require.Equal(t, "json", c.Req().Accepts("json", "text"))
require.Equal(t, "application/json", c.Accepts("application/json"))
require.Equal(t, "", c.Accepts("image/png"))
require.Equal(t, "", c.Accepts("png"))
c.Request().Header.Set(HeaderAccept, "text/html, application/json")
require.Equal(t, "text/*", c.Accepts("text/*"))
require.Equal(t, "text/*", c.Req().Accepts("text/*"))
c.Request().Header.Set(HeaderAccept, "*/*")
require.Equal(t, "html", c.Accepts("html"))
@ -192,7 +192,7 @@ func Test_Ctx_AcceptsCharsets(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_AcceptsCharsets -benchmem -count=4
func Benchmark_Ctx_AcceptsCharsets(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().Header.Set("Accept-Charset", "utf-8, iso-8859-1;q=0.5")
var res string
@ -218,7 +218,7 @@ func Test_Ctx_AcceptsEncodings(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_AcceptsEncodings -benchmem -count=4
func Benchmark_Ctx_AcceptsEncodings(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().Header.Set(HeaderAcceptEncoding, "deflate, gzip;q=1.0, *;q=0.5")
var res string
@ -243,7 +243,7 @@ func Test_Ctx_AcceptsLanguages(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_AcceptsLanguages -benchmem -count=4
func Benchmark_Ctx_AcceptsLanguages(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().Header.Set(HeaderAcceptLanguage, "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5")
var res string
@ -304,7 +304,7 @@ func Test_Ctx_Append(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Append -benchmem -count=4
func Benchmark_Ctx_Append(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -337,7 +337,7 @@ func Test_Ctx_Attachment(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Attachment -benchmem -count=4
func Benchmark_Ctx_Attachment(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -363,7 +363,7 @@ func Test_Ctx_BaseURL(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_BaseURL -benchmem
func Benchmark_Ctx_BaseURL(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetHost("google.com:1337")
c.Request().URI().SetPath("/haha/oke/lol")
@ -380,7 +380,7 @@ func Benchmark_Ctx_BaseURL(b *testing.B) {
func Test_Ctx_Body(t *testing.T) {
t.Parallel()
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBody([]byte("john=doe"))
require.Equal(t, []byte("john=doe"), c.Body())
@ -390,7 +390,7 @@ func Test_Ctx_Body(t *testing.T) {
func Test_Ctx_BodyRaw(t *testing.T) {
t.Parallel()
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBodyRaw([]byte("john=doe"))
require.Equal(t, []byte("john=doe"), c.BodyRaw())
@ -400,7 +400,7 @@ func Test_Ctx_BodyRaw(t *testing.T) {
func Test_Ctx_BodyRaw_Immutable(t *testing.T) {
t.Parallel()
app := New(Config{Immutable: true})
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBodyRaw([]byte("john=doe"))
require.Equal(t, []byte("john=doe"), c.BodyRaw())
@ -411,7 +411,7 @@ func Benchmark_Ctx_Body(b *testing.B) {
const input = "john=doe"
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBody([]byte(input))
b.ReportAllocs()
@ -428,7 +428,7 @@ func Benchmark_Ctx_BodyRaw(b *testing.B) {
const input = "john=doe"
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBodyRaw([]byte(input))
b.ReportAllocs()
@ -445,7 +445,7 @@ func Benchmark_Ctx_BodyRaw_Immutable(b *testing.B) {
const input = "john=doe"
app := New(Config{Immutable: true})
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBodyRaw([]byte(input))
b.ReportAllocs()
@ -462,7 +462,7 @@ func Test_Ctx_Body_Immutable(t *testing.T) {
t.Parallel()
app := New()
app.config.Immutable = true
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBody([]byte("john=doe"))
require.Equal(t, []byte("john=doe"), c.Body())
@ -474,7 +474,7 @@ func Benchmark_Ctx_Body_Immutable(b *testing.B) {
app := New()
app.config.Immutable = true
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().SetBody([]byte(input))
b.ReportAllocs()
@ -527,7 +527,7 @@ func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().Header.Set("Content-Encoding", tCase.contentEncoding)
if strings.Contains(tCase.contentEncoding, "gzip") {
@ -720,7 +720,7 @@ func Test_Ctx_Body_With_Compression_Immutable(t *testing.T) {
t.Parallel()
app := New()
app.config.Immutable = true
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().Header.Set("Content-Encoding", tCase.contentEncoding)
if strings.Contains(tCase.contentEncoding, "gzip") {
@ -897,7 +897,7 @@ func Test_Ctx_Context(t *testing.T) {
t.Parallel()
testKey := struct{}{}
testValue := "Test Value"
ctx := context.WithValue(context.Background(), testKey, testValue) //nolint: staticcheck // not needed for tests
ctx := context.WithValue(context.Background(), testKey, testValue) //nolint:staticcheck // not needed for tests
require.Equal(t, testValue, ctx.Value(testKey))
})
}
@ -910,7 +910,7 @@ func Test_Ctx_SetContext(t *testing.T) {
testKey := struct{}{}
testValue := "Test Value"
ctx := context.WithValue(context.Background(), testKey, testValue) //nolint: staticcheck // not needed for tests
ctx := context.WithValue(context.Background(), testKey, testValue) //nolint:staticcheck // not needed for tests
c.SetContext(ctx)
require.Equal(t, testValue, c.Context().Value(testKey))
}
@ -930,7 +930,7 @@ func Test_Ctx_Context_Multiple_Requests(t *testing.T) {
}
input := utils.CopyString(Query(c, "input", "NO_VALUE"))
ctx = context.WithValue(ctx, testKey, fmt.Sprintf("%s_%s", testValue, input)) //nolint: staticcheck // not needed for tests
ctx = context.WithValue(ctx, testKey, fmt.Sprintf("%s_%s", testValue, input)) //nolint:staticcheck // not needed for tests
c.SetContext(ctx)
return c.Status(StatusOK).SendString(fmt.Sprintf("resp_%s_returned", input))
@ -968,52 +968,52 @@ func Test_Ctx_Cookie(t *testing.T) {
Expires: expire,
// SameSite: CookieSameSiteStrictMode, // default is "lax"
}
c.Cookie(cookie)
c.Res().Cookie(cookie)
expect := "username=john; expires=" + httpdate + "; path=/; SameSite=Lax"
require.Equal(t, expect, string(c.Response().Header.Peek(HeaderSetCookie)))
require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
expect = "username=john; expires=" + httpdate + "; path=/"
cookie.SameSite = CookieSameSiteDisabled
c.Cookie(cookie)
require.Equal(t, expect, string(c.Response().Header.Peek(HeaderSetCookie)))
c.Res().Cookie(cookie)
require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
expect = "username=john; expires=" + httpdate + "; path=/; SameSite=Strict"
cookie.SameSite = CookieSameSiteStrictMode
c.Cookie(cookie)
require.Equal(t, expect, string(c.Response().Header.Peek(HeaderSetCookie)))
c.Res().Cookie(cookie)
require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
expect = "username=john; expires=" + httpdate + "; path=/; secure; SameSite=None"
cookie.Secure = true
cookie.SameSite = CookieSameSiteNoneMode
c.Cookie(cookie)
require.Equal(t, expect, string(c.Response().Header.Peek(HeaderSetCookie)))
c.Res().Cookie(cookie)
require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
expect = "username=john; path=/; secure; SameSite=None"
// should remove expires and max-age headers
cookie.SessionOnly = true
cookie.Expires = expire
cookie.MaxAge = 10000
c.Cookie(cookie)
require.Equal(t, expect, string(c.Response().Header.Peek(HeaderSetCookie)))
c.Res().Cookie(cookie)
require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
expect = "username=john; path=/; secure; SameSite=None"
// should remove expires and max-age headers when no expire and no MaxAge (default time)
cookie.SessionOnly = false
cookie.Expires = time.Time{}
cookie.MaxAge = 0
c.Cookie(cookie)
require.Equal(t, expect, string(c.Response().Header.Peek(HeaderSetCookie)))
c.Res().Cookie(cookie)
require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
expect = "username=john; path=/; secure; SameSite=None; Partitioned"
cookie.Partitioned = true
c.Cookie(cookie)
require.Equal(t, expect, string(c.Response().Header.Peek(HeaderSetCookie)))
c.Res().Cookie(cookie)
require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
}
// go test -v -run=^$ -bench=Benchmark_Ctx_Cookie -benchmem -count=4
func Benchmark_Ctx_Cookie(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -1033,8 +1033,8 @@ func Test_Ctx_Cookies(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.Set("Cookie", "john=doe")
require.Equal(t, "doe", c.Cookies("john"))
require.Equal(t, "default", c.Cookies("unknown", "default"))
require.Equal(t, "doe", c.Req().Cookies("john"))
require.Equal(t, "default", c.Req().Cookies("unknown", "default"))
}
// go test -run Test_Ctx_Format
@ -1058,13 +1058,13 @@ func Test_Ctx_Format(t *testing.T) {
}
c.Request().Header.Set(HeaderAccept, `text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7`)
err := c.Format(formatHandlers("application/xhtml+xml", "application/xml", "foo/bar")...)
err := c.Res().Format(formatHandlers("application/xhtml+xml", "application/xml", "foo/bar")...)
require.Equal(t, "application/xhtml+xml", accepted)
require.Equal(t, "application/xhtml+xml", c.GetRespHeader(HeaderContentType))
require.NoError(t, err)
require.NotEqual(t, StatusNotAcceptable, c.Response().StatusCode())
err = c.Format(formatHandlers("foo/bar;a=b")...)
err = c.Res().Format(formatHandlers("foo/bar;a=b")...)
require.Equal(t, "foo/bar;a=b", accepted)
require.Equal(t, "foo/bar;a=b", c.GetRespHeader(HeaderContentType))
require.NoError(t, err)
@ -1165,7 +1165,7 @@ func Test_Ctx_AutoFormat(t *testing.T) {
require.Equal(t, "Hello, World!", string(c.Response().Body()))
c.Request().Header.Set(HeaderAccept, MIMETextHTML)
err = c.AutoFormat("Hello, World!")
err = c.Res().AutoFormat("Hello, World!")
require.NoError(t, err)
require.Equal(t, "<p>Hello, World!</p>", string(c.Response().Body()))
@ -1175,7 +1175,7 @@ func Test_Ctx_AutoFormat(t *testing.T) {
require.Equal(t, `"Hello, World!"`, string(c.Response().Body()))
c.Request().Header.Set(HeaderAccept, MIMETextPlain)
err = c.AutoFormat(complex(1, 1))
err = c.Res().AutoFormat(complex(1, 1))
require.NoError(t, err)
require.Equal(t, "(1+1i)", string(c.Response().Body()))
@ -1544,12 +1544,12 @@ func Test_Ctx_Binders(t *testing.T) {
t.Run("URI", func(t *testing.T) {
t.Skip("URI is not ready for v3")
//nolint:gocritic // TODO: uncomment
//t.Parallel()
//withValues(t, func(c Ctx, testStruct *TestStruct) error {
// t.Parallel()
// withValues(t, func(c Ctx, testStruct *TestStruct) error {
// c.Route().Params = []string{"name", "name2", "class", "class2"}
// c.Params().value = [30]string{"foo", "bar", "111", "222"}
// return c.Bind().URI(testStruct)
//})
// })
})
t.Run("ReqHeader", func(t *testing.T) {
t.Parallel()
@ -2566,7 +2566,7 @@ func Test_Ctx_Params_Case_Sensitive(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Params -benchmem -count=4
func Benchmark_Ctx_Params(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.route = &Route{
Params: []string{
@ -2939,7 +2939,7 @@ func Test_Ctx_SaveFile(t *testing.T) {
app := New()
app.Post("/test", func(c Ctx) error {
fh, err := c.FormFile("file")
fh, err := c.Req().FormFile("file")
require.NoError(t, err)
tempFile, err := os.CreateTemp(os.TempDir(), "test-")
@ -3075,7 +3075,7 @@ func Test_Ctx_ClearCookie(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.Set(HeaderCookie, "john=doe")
c.ClearCookie("john")
c.Res().ClearCookie("john")
require.True(t, strings.HasPrefix(string(c.Response().Header.Peek(HeaderSetCookie)), "john=; expires="))
c.Request().Header.Set(HeaderCookie, "test1=dummy")
@ -3104,7 +3104,7 @@ func Test_Ctx_Download(t *testing.T) {
require.Equal(t, expect, c.Response().Body())
require.Equal(t, `attachment; filename="Awesome+File%21"`, string(c.Response().Header.Peek(HeaderContentDisposition)))
require.NoError(t, c.Download("ctx.go"))
require.NoError(t, c.Res().Download("ctx.go"))
require.Equal(t, `attachment; filename="ctx.go"`, string(c.Response().Header.Peek(HeaderContentDisposition)))
}
@ -3136,7 +3136,7 @@ func Test_Ctx_SendFile(t *testing.T) {
// test with custom error code
c = app.AcquireCtx(&fasthttp.RequestCtx{})
err = c.Status(StatusInternalServerError).SendFile("ctx.go")
err = c.Res().Status(StatusInternalServerError).SendFile("ctx.go")
// check expectation
require.NoError(t, err)
require.Equal(t, expectFileContent, c.Response().Body())
@ -3161,7 +3161,7 @@ func Test_Ctx_SendFile_ContentType(t *testing.T) {
// 1) simple case
c := app.AcquireCtx(&fasthttp.RequestCtx{})
err := c.SendFile("./.github/testdata/fs/img/fiber.png")
err := c.Res().SendFile("./.github/testdata/fs/img/fiber.png")
// check expectation
require.NoError(t, err)
require.Equal(t, StatusOK, c.Response().StatusCode())
@ -3746,7 +3746,7 @@ func Benchmark_Ctx_CBOR(b *testing.B) {
func Benchmark_Ctx_JSON_Ctype(b *testing.B) {
app := New()
// TODO: Check extra allocs because of the interface stuff
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
type SomeStruct struct {
Name string
Age uint8
@ -3782,7 +3782,7 @@ func Test_Ctx_JSONP(t *testing.T) {
require.Equal(t, `callback({"Age":20,"Name":"Grame"});`, string(c.Response().Body()))
require.Equal(t, "text/javascript; charset=utf-8", string(c.Response().Header.Peek("content-type")))
err = c.JSONP(Map{
err = c.Res().JSONP(Map{
"Name": "Grame",
"Age": 20,
}, "john")
@ -3813,7 +3813,7 @@ func Test_Ctx_JSONP(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_JSONP -benchmem -count=4
func Benchmark_Ctx_JSONP(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
type SomeStruct struct {
Name string
@ -3838,7 +3838,7 @@ func Benchmark_Ctx_JSONP(b *testing.B) {
func Test_Ctx_XML(t *testing.T) {
t.Parallel()
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
require.Error(t, c.JSON(complex(1, 1)))
@ -3897,7 +3897,7 @@ func Test_Ctx_XML(t *testing.T) {
// go test -run=^$ -bench=Benchmark_Ctx_XML -benchmem -count=4
func Benchmark_Ctx_XML(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
type SomeStruct struct {
Name string `xml:"Name"`
Age uint8 `xml:"Age"`
@ -3936,7 +3936,7 @@ func Test_Ctx_Links(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Links -benchmem -count=4
func Benchmark_Ctx_Links(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -4006,7 +4006,7 @@ func Test_Ctx_Render(t *testing.T) {
err = c.Render("./.github/testdata/template-non-exists.html", nil)
require.Error(t, err)
err = c.Render("./.github/testdata/template-invalid.html", nil)
err = c.Res().Render("./.github/testdata/template-invalid.html", nil)
require.Error(t, err)
}
@ -4363,7 +4363,7 @@ func Benchmark_Ctx_Render_Engine(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Get_Location_From_Route -benchmem -count=4
func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
app.Get("/user/:name", func(c Ctx) error {
return c.SendString(c.Params("name"))
@ -4578,14 +4578,14 @@ func Test_Ctx_SendStreamWriter(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
err := c.SendStreamWriter(func(w *bufio.Writer) {
w.WriteString("Don't crash please") //nolint:errcheck, revive // It is fine to ignore the error
w.WriteString("Don't crash please") //nolint:errcheck // It is fine to ignore the error
})
require.NoError(t, err)
require.Equal(t, "Don't crash please", string(c.Response().Body()))
err = c.SendStreamWriter(func(w *bufio.Writer) {
for lineNum := 1; lineNum <= 5; lineNum++ {
fmt.Fprintf(w, "Line %d\n", lineNum) //nolint:errcheck, revive // It is fine to ignore the error
fmt.Fprintf(w, "Line %d\n", lineNum)
if err := w.Flush(); err != nil {
t.Errorf("unexpected error: %s", err)
return
@ -4607,7 +4607,7 @@ func Test_Ctx_SendStreamWriter_Interrupted(t *testing.T) {
app.Get("/", func(c Ctx) error {
return c.SendStreamWriter(func(w *bufio.Writer) {
for lineNum := 1; lineNum <= 5; lineNum++ {
fmt.Fprintf(w, "Line %d\n", lineNum) //nolint:errcheck // It is fine to ignore the error
fmt.Fprintf(w, "Line %d\n", lineNum)
if err := w.Flush(); err != nil {
if lineNum < 3 {
@ -4728,7 +4728,7 @@ func Benchmark_Ctx_Type(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Type_Charset -benchmem -count=4
func Benchmark_Ctx_Type_Charset(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -4753,7 +4753,7 @@ func Test_Ctx_Vary(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Vary -benchmem -count=4
func Benchmark_Ctx_Vary(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -4806,7 +4806,7 @@ func Test_Ctx_Writef(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Ctx_Writef -benchmem -count=4
func Benchmark_Ctx_Writef(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
world := "World!"
b.ReportAllocs()
@ -4907,7 +4907,7 @@ func Test_Ctx_Queries(t *testing.T) {
c.Request().URI().SetQueryString("tags=apple,orange,banana&filters[tags]=apple,orange,banana&filters[category][name]=fruits&filters.tags=apple,orange,banana&filters.category.name=fruits")
queries = c.Queries()
queries = c.Req().Queries()
require.Equal(t, "apple,orange,banana", queries["tags"])
require.Equal(t, "apple,orange,banana", queries["filters[tags]"])
require.Equal(t, "fruits", queries["filters[category][name]"])
@ -4951,11 +4951,11 @@ func Test_Ctx_BodyStreamWriter(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "body writer line 1\n") //nolint: errcheck // It is fine to ignore the error
fmt.Fprintf(w, "body writer line 1\n")
if err := w.Flush(); err != nil {
t.Errorf("unexpected error: %s", err)
}
fmt.Fprintf(w, "body writer line 2\n") //nolint: errcheck // It is fine to ignore the error
fmt.Fprintf(w, "body writer line 2\n")
})
require.True(t, ctx.IsBodyStream())
@ -5055,7 +5055,7 @@ func Test_Ctx_IsFromLocal_X_Forwarded(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.Set(HeaderXForwardedFor, "93.46.8.90")
require.False(t, c.IsFromLocal())
require.False(t, c.Req().IsFromLocal())
}
}
@ -5088,8 +5088,8 @@ func Test_Ctx_IsFromLocal_RemoteAddr(t *testing.T) {
fastCtx := &fasthttp.RequestCtx{}
fastCtx.SetRemoteAddr(localIPv6)
c := app.AcquireCtx(fastCtx)
require.Equal(t, "::1", c.IP())
require.True(t, c.IsFromLocal())
require.Equal(t, "::1", c.Req().IP())
require.True(t, c.Req().IsFromLocal())
}
// Test for the case fasthttp remoteAddr is set to "0:0:0:0:0:0:0:1".
{
@ -5522,6 +5522,10 @@ func Test_GenericParseTypeUints(t *testing.T) {
value: uint(4),
str: "4",
},
{
value: ^uint(0),
str: strconv.FormatUint(uint64(^uint(0)), 10),
},
}
for _, test := range uints {

View File

@ -0,0 +1,9 @@
{
"label": "\uD83D\uDD0C Addon",
"position": 5,
"collapsed": true,
"link": {
"type": "generated-index",
"description": "Addon is an additional useful package that can be used in Fiber."
}
}

126
docs/addon/retry.md Normal file
View File

@ -0,0 +1,126 @@
---
id: retry
---
# Retry Addon
Retry addon for [Fiber](https://github.com/gofiber/fiber) designed to apply retry mechanism for unsuccessful network
operations. This addon uses an exponential backoff algorithm with jitter. It calls the function multiple times and tries
to make it successful. If all calls are failed, then, it returns an error. It adds a jitter at each retry step because adding
a jitter is a way to break synchronization across the client and avoid collision.
## Table of Contents
- [Retry Addon](#retry-addon)
- [Table of Contents](#table-of-contents)
- [Signatures](#signatures)
- [Examples](#examples)
- [Default Config](#default-config)
- [Custom Config](#custom-config)
- [Config](#config)
- [Default Config Example](#default-config-example)
## Signatures
```go
func NewExponentialBackoff(config ...retry.Config) *retry.ExponentialBackoff
```
## Examples
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3/addon/retry"
"github.com/gofiber/fiber/v3/client"
)
func main() {
expBackoff := retry.NewExponentialBackoff(retry.Config{})
// Local variables that will be used inside of Retry
var resp *client.Response
var err error
// Retry a network request and return an error to signify to try again
err = expBackoff.Retry(func() error {
client := client.New()
resp, err = client.Get("https://gofiber.io")
if err != nil {
return fmt.Errorf("GET gofiber.io failed: %w", err)
}
if resp.StatusCode() != 200 {
return fmt.Errorf("GET gofiber.io did not return OK 200")
}
return nil
})
// If all retries failed, panic
if err != nil {
panic(err)
}
fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode())
}
```
## Default Config
```go
retry.NewExponentialBackoff()
```
## Custom Config
```go
retry.NewExponentialBackoff(retry.Config{
InitialInterval: 2 * time.Second,
MaxBackoffTime: 64 * time.Second,
Multiplier: 2.0,
MaxRetryCount: 15,
})
```
## Config
```go
// Config defines the config for addon.
type Config struct {
// InitialInterval defines the initial time interval for backoff algorithm.
//
// Optional. Default: 1 * time.Second
InitialInterval time.Duration
// MaxBackoffTime defines maximum time duration for backoff algorithm. When
// the algorithm is reached this time, rest of the retries will be maximum
// 32 seconds.
//
// Optional. Default: 32 * time.Second
MaxBackoffTime time.Duration
// Multiplier defines multiplier number of the backoff algorithm.
//
// Optional. Default: 2.0
Multiplier float64
// MaxRetryCount defines maximum retry count for the backoff algorithm.
//
// Optional. Default: 10
MaxRetryCount int
}
```
## Default Config Example
```go
// DefaultConfig is the default config for retry.
var DefaultConfig = Config{
InitialInterval: 1 * time.Second,
MaxBackoffTime: 32 * time.Second,
Multiplier: 2.0,
MaxRetryCount: 10,
currentInterval: 1 * time.Second,
}
```

View File

@ -135,18 +135,18 @@ func (app *App) Route(path string) Register
```go
type Register interface {
All(handler Handler, middleware ...Handler) Register
Get(handler Handler, middleware ...Handler) Register
Head(handler Handler, middleware ...Handler) Register
Post(handler Handler, middleware ...Handler) Register
Put(handler Handler, middleware ...Handler) Register
Delete(handler Handler, middleware ...Handler) Register
Connect(handler Handler, middleware ...Handler) Register
Options(handler Handler, middleware ...Handler) Register
Trace(handler Handler, middleware ...Handler) Register
Patch(handler Handler, middleware ...Handler) Register
All(handler Handler, handlers ...Handler) Register
Get(handler Handler, handlers ...Handler) Register
Head(handler Handler, handlers ...Handler) Register
Post(handler Handler, handlers ...Handler) Register
Put(handler Handler, handlers ...Handler) Register
Delete(handler Handler, handlers ...Handler) Register
Connect(handler Handler, handlers ...Handler) Register
Options(handler Handler, handlers ...Handler) Register
Trace(handler Handler, handlers ...Handler) Register
Patch(handler Handler, handlers ...Handler) Register
Add(methods []string, handler Handler, middleware ...Handler) Register
Add(methods []string, handler Handler, handlers ...Handler) Register
Route(path string) Register
}

View File

@ -120,6 +120,38 @@ curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=j
curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000
```
:::info
If you need to bind multipart file, you can use `*multipart.FileHeader`, `*[]*multipart.FileHeader` or `[]*multipart.FileHeader` as a field type.
:::
```go title="Example"
type Person struct {
Name string `form:"name"`
Pass string `form:"pass"`
Avatar *multipart.FileHeader `form:"avatar"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Form(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
log.Println(p.Avatar.Filename) // file.txt
// ...
})
```
Run tests with the following `curl` command:
```bash
curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" -F 'avatar=@filename' localhost:3000
```
### JSON
Binds the request JSON body to a struct.

View File

@ -2,7 +2,7 @@
id: constants
title: 📋 Constants
description: Some constants for Fiber.
sidebar_position: 8
sidebar_position: 9
---
### HTTP methods were copied from net/http

File diff suppressed because it is too large Load Diff

View File

@ -111,11 +111,9 @@ app.Listen(":8080", fiber.ListenConfig{
| <Reference id="enableprefork">EnablePrefork</Reference> | `bool` | When set to true, this will spawn multiple Go processes listening on the same port. | `false` |
| <Reference id="enableprintroutes">EnablePrintRoutes</Reference> | `bool` | If set to true, will print all routes with their method, path, and handler. | `false` |
| <Reference id="gracefulcontext">GracefulContext</Reference> | `context.Context` | Field to shutdown Fiber by given context gracefully. | `nil` |
| <Reference id="ShutdownTimeout">ShutdownTimeout</Reference> | `time.Duration` | Specifies the maximum duration to wait for the server to gracefully shutdown. When the timeout is reached, the graceful shutdown process is interrupted and forcibly terminated, and the `context.DeadlineExceeded` error is passed to the `OnShutdownError` callback. Set to 0 to disable the timeout and wait indefinitely. | `10 * time.Second` |
| <Reference id="ShutdownTimeout">ShutdownTimeout</Reference> | `time.Duration` | Specifies the maximum duration to wait for the server to gracefully shutdown. When the timeout is reached, the graceful shutdown process is interrupted and forcibly terminated, and the `context.DeadlineExceeded` error is passed to the `OnPostShutdown` callback. Set to 0 to disable the timeout and wait indefinitely. | `10 * time.Second` |
| <Reference id="listeneraddrfunc">ListenerAddrFunc</Reference> | `func(addr net.Addr)` | Allows accessing and customizing `net.Listener`. | `nil` |
| <Reference id="listenernetwork">ListenerNetwork</Reference> | `string` | Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only). WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chosen. | `tcp4` |
| <Reference id="onshutdownerror">OnShutdownError</Reference> | `func(err error)` | Allows to customize error behavior when gracefully shutting down the server by given signal. Prints error with `log.Fatalf()` | `nil` |
| <Reference id="onshutdownsuccess">OnShutdownSuccess</Reference> | `func()` | Allows customizing success behavior when gracefully shutting down the server by given signal. | `nil` |
| <Reference id="tlsconfigfunc">TLSConfigFunc</Reference> | `func(tlsConfig *tls.Config)` | Allows customizing `tls.Config` as you want. | `nil` |
| <Reference id="autocertmanager">AutoCertManager</Reference> | `*autocert.Manager` | Manages TLS certificates automatically using the ACME protocol. Enables integration with Let's Encrypt or other ACME-compatible providers. | `nil` |
| <Reference id="tlsminversion">TLSMinVersion</Reference> | `uint16` | Allows customizing the TLS minimum version. | `tls.VersionTLS12` |
@ -230,7 +228,7 @@ Shutdown gracefully shuts down the server without interrupting any active connec
ShutdownWithTimeout will forcefully close any active connections after the timeout expires.
ShutdownWithContext shuts down the server including by force if the context's deadline is exceeded.
ShutdownWithContext shuts down the server including by force if the context's deadline is exceeded. Shutdown hooks will still be executed, even if an error occurs during the shutdown process, as they are deferred to ensure cleanup happens regardless of errors.
```go
func (app *App) Shutdown() error

View File

@ -15,7 +15,8 @@ With Fiber you can execute custom user functions at specific method execution po
- [OnGroupName](#ongroupname)
- [OnListen](#onlisten)
- [OnFork](#onfork)
- [OnShutdown](#onshutdown)
- [OnPreShutdown](#onpreshutdown)
- [OnPostShutdown](#onpostshutdown)
- [OnMount](#onmount)
## Constants
@ -28,7 +29,8 @@ type OnGroupHandler = func(Group) error
type OnGroupNameHandler = OnGroupHandler
type OnListenHandler = func(ListenData) error
type OnForkHandler = func(int) error
type OnShutdownHandler = func() error
type OnPreShutdownHandler = func() error
type OnPostShutdownHandler = func(error) error
type OnMountHandler = func(*App) error
```
@ -174,12 +176,20 @@ func main() {
func (h *Hooks) OnFork(handler ...OnForkHandler)
```
## OnShutdown
## OnPreShutdown
`OnShutdown` is a hook to execute user functions after shutdown.
`OnPreShutdown` is a hook to execute user functions before shutdown.
```go title="Signature"
func (h *Hooks) OnShutdown(handler ...OnShutdownHandler)
func (h *Hooks) OnPreShutdown(handler ...OnPreShutdownHandler)
```
## OnPostShutdown
`OnPostShutdown` is a hook to execute user functions after shutdown.
```go title="Signature"
func (h *Hooks) OnPostShutdown(handler ...OnPostShutdownHandler)
```
## OnMount

642
docs/api/state.md Normal file
View File

@ -0,0 +1,642 @@
---
id: state
title: 🗂️ State Management
sidebar_position: 8
---
The State Management provides a global keyvalue store for managing application dependencies and runtime data. This store is shared across the entire application and remains consistent between requests. It is implemented using Gos `sync.Map` to ensure safe concurrent access.
## State Type
`State` is a keyvalue store built on top of `sync.Map`. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as threadsafe access to runtime data.
### Definition
```go
// State is a keyvalue store for Fiber's app, used as a global storage for the app's dependencies.
// It is a threadsafe implementation of a map[string]any, using sync.Map.
type State struct {
dependencies sync.Map
}
```
## Methods on State
### Set
Set adds or updates a keyvalue pair in the State.
```go
// Set adds or updates a keyvalue pair in the State.
func (s *State) Set(key string, value any)
```
**Usage Example:**
```go
app.State().Set("appName", "My Fiber App")
```
### Get
Get retrieves a value from the State.
```go title="Signature"
func (s *State) Get(key string) (any, bool)
```
**Usage Example:**
```go
value, ok := app.State().Get("appName")
if ok {
fmt.Println("App Name:", value)
}
```
### MustGet
MustGet retrieves a value from the State and panics if the key is not found.
```go title="Signature"
func (s *State) MustGet(key string) any
```
**Usage Example:**
```go
appName := app.State().MustGet("appName")
fmt.Println("App Name:", appName)
```
### Has
Has checks if a key exists in the State.
```go title="Signature"s
func (s *State) Has(key string) bool
```
**Usage Example:**
```go
if app.State().Has("appName") {
fmt.Println("App Name is set.")
}
```
### Delete
Delete removes a keyvalue pair from the State.
```go title="Signature"
func (s *State) Delete(key string)
```
**Usage Example:**
```go
app.State().Delete("obsoleteKey")
```
### Reset
Reset removes all keys from the State.
```go title="Signature"
func (s *State) Reset()
```
**Usage Example:**
```go
app.State().Reset()
```
### Keys
Keys returns a slice containing all keys present in the State.
```go title="Signature"
func (s *State) Keys() []string
```
**Usage Example:**
```go
keys := app.State().Keys()
fmt.Println("State Keys:", keys)
```
### Len
Len returns the number of keys in the State.
```go
// Len returns the number of keys in the State.
func (s *State) Len() int
```
**Usage Example:**
```go
fmt.Printf("Total State Entries: %d\n", app.State().Len())
```
### GetString
GetString retrieves a string value from the State. It returns the string and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetString(key string) (string, bool)
```
**Usage Example:**
```go
if appName, ok := app.State().GetString("appName"); ok {
fmt.Println("App Name:", appName)
}
```
### GetInt
GetInt retrieves an integer value from the State. It returns the int and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetInt(key string) (int, bool)
```
**Usage Example:**
```go
if count, ok := app.State().GetInt("userCount"); ok {
fmt.Printf("User Count: %d\n", count)
}
```
### GetBool
GetBool retrieves a boolean value from the State. It returns the bool and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetBool(key string) (value, bool)
```
**Usage Example:**
```go
if debug, ok := app.State().GetBool("debugMode"); ok {
fmt.Printf("Debug Mode: %v\n", debug)
}
```
### GetFloat64
GetFloat64 retrieves a float64 value from the State. It returns the float64 and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetFloat64(key string) (float64, bool)
```
**Usage Example:**
```go title="Signature"
if ratio, ok := app.State().GetFloat64("scalingFactor"); ok {
fmt.Printf("Scaling Factor: %f\n", ratio)
}
```
### GetUint
GetUint retrieves a `uint` value from the State.
```go title="Signature"
func (s *State) GetUint(key string) (uint, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint("maxConnections"); ok {
fmt.Printf("Max Connections: %d\n", val)
}
```
### GetInt8
GetInt8 retrieves an `int8` value from the State.
```go title="Signature"
func (s *State) GetInt8(key string) (int8, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt8("threshold"); ok {
fmt.Printf("Threshold: %d\n", val)
}
```
### GetInt16
GetInt16 retrieves an `int16` value from the State.
```go title="Signature"
func (s *State) GetInt16(key string) (int16, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt16("minValue"); ok {
fmt.Printf("Minimum Value: %d\n", val)
}
```
### GetInt32
GetInt32 retrieves an `int32` value from the State.
```go title="Signature"
func (s *State) GetInt32(key string) (int32, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt32("portNumber"); ok {
fmt.Printf("Port Number: %d\n", val)
}
```
### GetInt64
GetInt64 retrieves an `int64` value from the State.
```go title="Signature"
func (s *State) GetInt64(key string) (int64, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt64("fileSize"); ok {
fmt.Printf("File Size: %d\n", val)
}
```
### GetUint8
GetUint8 retrieves a `uint8` value from the State.
```go title="Signature"
func (s *State) GetUint8(key string) (uint8, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint8("byteValue"); ok {
fmt.Printf("Byte Value: %d\n", val)
}
```
### GetUint16
GetUint16 retrieves a `uint16` value from the State.
```go title="Signature"
func (s *State) GetUint16(key string) (uint16, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint16("limit"); ok {
fmt.Printf("Limit: %d\n", val)
}
```
### GetUint32
GetUint32 retrieves a `uint32` value from the State.
```go title="Signature"
func (s *State) GetUint32(key string) (uint32, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint32("timeout"); ok {
fmt.Printf("Timeout: %d\n", val)
}
```
### GetUint64
GetUint64 retrieves a `uint64` value from the State.
```go title="Signature"
func (s *State) GetUint64(key string) (uint64, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint64("maxSize"); ok {
fmt.Printf("Max Size: %d\n", val)
}
```
### GetUintptr
GetUintptr retrieves a `uintptr` value from the State.
```go title="Signature"
func (s *State) GetUintptr(key string) (uintptr, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUintptr("pointerValue"); ok {
fmt.Printf("Pointer Value: %d\n", val)
}
```
### GetFloat32
GetFloat32 retrieves a `float32` value from the State.
```go title="Signature"
func (s *State) GetFloat32(key string) (float32, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetFloat32("scalingFactor32"); ok {
fmt.Printf("Scaling Factor (float32): %f\n", val)
}
```
### GetComplex64
GetComplex64 retrieves a `complex64` value from the State.
```go title="Signature"
func (s *State) GetComplex64(key string) (complex64, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetComplex64("complexVal"); ok {
fmt.Printf("Complex Value (complex64): %v\n", val)
}
```
### GetComplex128
GetComplex128 retrieves a `complex128` value from the State.
```go title="Signature"
func (s *State) GetComplex128(key string) (complex128, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetComplex128("complexVal128"); ok {
fmt.Printf("Complex Value (complex128): %v\n", val)
}
```
## Generic Functions
Fiber provides generic functions to retrieve state values with type safety and fallback options.
### GetState
GetState retrieves a value from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful.
```go title="Signature"
func GetState[T any](s *State, key string) (T, bool)
```
**Usage Example:**
```go
// Retrieve an integer value safely.
userCount, ok := GetState[int](app.State(), "userCount")
if ok {
fmt.Printf("User Count: %d\n", userCount)
}
```
### MustGetState
MustGetState retrieves a value from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails.
```go title="Signature"
func MustGetState[T any](s *State, key string) T
```
**Usage Example:**
```go
// Retrieve the value or panic if it is not present.
config := MustGetState[string](app.State(), "configFile")
fmt.Println("Config File:", config)
```
### GetStateWithDefault
GetStateWithDefault retrieves a value from the State, casting it to the desired type. If the key is not present, it returns the provided default value.
```go title="Signature"
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T
```
**Usage Example:**
```go
// Retrieve a value with a default fallback.
requestCount := GetStateWithDefault[int](app.State(), "requestCount", 0)
fmt.Printf("Request Count: %d\n", requestCount)
```
## Comprehensive Examples
### Example: Request Counter
This example demonstrates how to track the number of requests using the State.
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Initialize state with a counter.
app.State().Set("requestCount", 0)
// Middleware: Increase counter for every request.
app.Use(func(c fiber.Ctx) error {
count, _ := c.App().State().GetInt("requestCount")
app.State().Set("requestCount", count+1)
return c.Next()
})
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello World!")
})
app.Get("/stats", func(c fiber.Ctx) error {
count, _ := c.App().State().Get("requestCount")
return c.SendString(fmt.Sprintf("Total requests: %d", count))
})
app.Listen(":3000")
}
```
### Example: EnvironmentSpecific Configuration
This example shows how to configure different settings based on the environment.
```go
package main
import (
"os"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Determine environment.
environment := os.Getenv("ENV")
if environment == "" {
environment = "development"
}
app.State().Set("environment", environment)
// Set environment-specific configurations.
if environment == "development" {
app.State().Set("apiUrl", "http://localhost:8080/api")
app.State().Set("debug", true)
} else {
app.State().Set("apiUrl", "https://api.production.com")
app.State().Set("debug", false)
}
app.Get("/config", func(c fiber.Ctx) error {
config := map[string]any{
"environment": environment,
"apiUrl": fiber.GetStateWithDefault(c.App().State(), "apiUrl", ""),
"debug": fiber.GetStateWithDefault(c.App().State(), "debug", false),
}
return c.JSON(config)
})
app.Listen(":3000")
}
```
### Example: Dependency Injection with State Management
This example demonstrates how to use the State for dependency injection in a Fiber application.
```go
package main
import (
"context"
"fmt"
"log"
"github.com/gofiber/fiber/v3"
"github.com/redis/go-redis/v9"
)
type User struct {
ID int `query:"id"`
Name string `query:"name"`
Email string `query:"email"`
}
func main() {
app := fiber.New()
ctx := context.Background()
// Initialize Redis client.
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "",
DB: 0,
})
// Check the Redis connection.
if err := rdb.Ping(ctx).Err(); err != nil {
log.Fatalf("Could not connect to Redis: %v", err)
}
// Inject the Redis client into Fiber's State for dependency injection.
app.State().Set("redis", rdb)
app.Get("/user/create", func(c fiber.Ctx) error {
var user User
if err := c.Bind().Query(&user); err != nil {
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
}
// Save the user to the database.
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
if !ok {
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
}
// Save the user to the database.
key := fmt.Sprintf("user:%d", user.ID)
err := rdb.HSet(ctx, key, "name", user.Name, "email", user.Email).Err()
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(user)
})
app.Get("/user/:id", func(c fiber.Ctx) error {
id := c.Params("id")
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
if !ok {
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
}
key := fmt.Sprintf("user:%s", id)
user, err := rdb.HGetAll(ctx, key).Result()
if err == redis.Nil {
return c.Status(fiber.StatusNotFound).SendString("User not found")
} else if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(user)
})
app.Listen(":3000")
}
```

View File

@ -1,6 +1,6 @@
{
"label": "\uD83C\uDF0E Client",
"position": 5,
"position": 6,
"link": {
"type": "generated-index",
"description": "HTTP client for Fiber."

View File

@ -1,6 +1,6 @@
{
"label": "\uD83E\uDDE9 Extra",
"position": 6,
"position": 8,
"link": {
"type": "generated-index",
"description": "Extra contents for Fiber."

View File

@ -30,7 +30,7 @@ app.Use(func(c fiber.Ctx) error {
})
```
## How can i use live reload ?
## How can I use live reload?
[Air](https://github.com/air-verse/air) is a handy tool that automatically restarts your Go applications whenever the source code changes, making your development process faster and more efficient.
@ -99,10 +99,12 @@ If you have questions or just want to have a chat, feel free to join us via this
![](/img/support-discord.png)
## Does fiber support sub domain routing ?
## Does Fiber support subdomain routing?
Yes we do, here are some examples:
This example works v2
<details>
<summary>Example</summary>
```go
package main
@ -170,4 +172,18 @@ func main() {
}
```
</details>
If more information is needed, please refer to this issue [#750](https://github.com/gofiber/fiber/issues/750)
## How can I handle conversions between Fiber and net/http?
The `adaptor` middleware provides utilities for converting between Fiber and `net/http`. It allows seamless integration of `net/http` handlers, middleware, and requests into Fiber applications, and vice versa.
For details on how to:
* Convert `net/http` handlers to Fiber handlers
* Convert Fiber handlers to `net/http` handlers
* Convert `fiber.Ctx` to `http.Request`
See the dedicated documentation: [Adaptor Documentation](../middleware/adaptor.md).

View File

@ -1,6 +1,6 @@
{
"label": "\uD83D\uDCD6 Guide",
"position": 5,
"position": 7,
"link": {
"type": "generated-index",
"description": "Guides for Fiber."

View File

@ -250,6 +250,10 @@ app.Get("/:test<int>?", func(c fiber.Ctx) error {
Custom constraints can be added to Fiber using the `app.RegisterCustomConstraint` method. Your constraints have to be compatible with the `CustomConstraint` interface.
:::caution
Attention, custom constraints can now override built-in constraints. If a custom constraint has the same name as a built-in constraint, the custom constraint will be used instead. This allows for more flexibility in defining route parameter constraints.
:::
It is a good idea to add external constraints to your project once you want to add more specific rules to your routes.
For example, you can add a constraint to check if a parameter is a valid ULID.

View File

@ -8,8 +8,7 @@ sidebar_position: 5
Fiber provides the [Bind](../api/bind.md#validation) function to validate and bind [request data](../api/bind.md#binders) to a struct.
```go title="Example"
```go title="Basic Example"
import "github.com/go-playground/validator/v10"
type structValidator struct {
@ -42,3 +41,71 @@ app.Post("/", func(c fiber.Ctx) error {
return c.JSON(user)
})
```
```go title="Advanced Validation Example"
type User struct {
Name string `json:"name" validate:"required,min=3,max=32"`
Email string `json:"email" validate:"required,email"`
Age int `json:"age" validate:"gte=0,lte=100"`
Password string `json:"password" validate:"required,min=8"`
Website string `json:"website" validate:"url"`
}
// Custom validation error messages
type UserWithCustomMessages struct {
Name string `json:"name" validate:"required,min=3,max=32" message:"Name is required and must be between 3 and 32 characters"`
Email string `json:"email" validate:"required,email" message:"Valid email is required"`
Age int `json:"age" validate:"gte=0,lte=100" message:"Age must be between 0 and 100"`
}
app.Post("/user", func(c fiber.Ctx) error {
user := new(User)
if err := c.Bind().Body(user); err != nil {
// Handle validation errors
if validationErrors, ok := err.(validator.ValidationErrors); ok {
for _, e := range validationErrors {
// e.Field() - field name
// e.Tag() - validation tag
// e.Value() - invalid value
// e.Param() - validation parameter
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"field": e.Field(),
"error": e.Error(),
})
}
}
return err
}
return c.JSON(user)
})
```
```go title="Custom Validator Example"
// Custom validator for password strength
type PasswordValidator struct {
validate *validator.Validate
}
func (v *PasswordValidator) Validate(out any) error {
if err := v.validate.Struct(out); err != nil {
return err
}
// Custom password validation logic
if user, ok := out.(*User); ok {
if len(user.Password) < 8 {
return errors.New("password must be at least 8 characters")
}
// Add more password validation rules here
}
return nil
}
// Usage
app := fiber.New(fiber.Config{
StructValidator: &PasswordValidator{validate: validator.New()},
})
```

View File

@ -4,24 +4,34 @@ id: adaptor
# Adaptor
Converter for net/http handlers to/from Fiber request handlers, special thanks to [@arsmn](https://github.com/arsmn)!
The `adaptor` package provides utilities for converting between Fiber and `net/http`. It allows seamless integration of `net/http` handlers, middleware, and requests into Fiber applications, and vice versa.
## Signatures
## Features
| Name | Signature | Description
| :--- | :--- | :---
| HTTPHandler | `HTTPHandler(h http.Handler) fiber.Handler` | http.Handler -> fiber.Handler
| HTTPHandlerFunc | `HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler` | http.HandlerFunc -> fiber.Handler
| HTTPMiddleware | `HTTPHandlerFunc(mw func(http.Handler) http.Handler) fiber.Handler` | func(http.Handler) http.Handler -> fiber.Handler
| FiberHandler | `FiberHandler(h fiber.Handler) http.Handler` | fiber.Handler -> http.Handler
| FiberHandlerFunc | `FiberHandlerFunc(h fiber.Handler) http.HandlerFunc` | fiber.Handler -> http.HandlerFunc
| FiberApp | `FiberApp(app *fiber.App) http.HandlerFunc` | Fiber app -> http.HandlerFunc
| ConvertRequest | `ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error)` | fiber.Ctx -> http.Request
| CopyContextToFiberContext | `CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx)` | context.Context -> fasthttp.RequestCtx
- Convert `net/http` handlers and middleware to Fiber handlers.
- Convert Fiber handlers to `net/http` handlers.
- Convert Fiber context (`fiber.Ctx`) into an `http.Request`.
## Examples
## API Reference
### net/http to Fiber
| Name | Signature | Description |
|-----------------------------|-------------------------------------------------------------------------------|------------------------------------------------------------------|
| `HTTPHandler` | `HTTPHandler(h http.Handler) fiber.Handler` | Converts `http.Handler` to `fiber.Handler` |
| `HTTPHandlerFunc` | `HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler` | Converts `http.HandlerFunc` to `fiber.Handler` |
| `HTTPMiddleware` | `HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler` | Converts `http.Handler` middleware to `fiber.Handler` middleware |
| `FiberHandler` | `FiberHandler(h fiber.Handler) http.Handler` | Converts `fiber.Handler` to `http.Handler` |
| `FiberHandlerFunc` | `FiberHandlerFunc(h fiber.Handler) http.HandlerFunc` | Converts `fiber.Handler` to `http.HandlerFunc` |
| `FiberApp` | `FiberApp(app *fiber.App) http.HandlerFunc` | Converts an entire Fiber app to a `http.HandlerFunc` |
| `ConvertRequest` | `ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error)` | Converts `fiber.Ctx` into a `http.Request` |
| `CopyContextToFiberContext` | `CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx)` | Copies `context.Context` to `fasthttp.RequestCtx` |
---
## Usage Examples
### 1. Using `net/http` Handlers in Fiber
This example demonstrates how to use standard `net/http` handlers inside a Fiber application:
```go
package main
@ -29,35 +39,27 @@ package main
import (
"fmt"
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
// New fiber app
app := fiber.New()
// http.Handler -> fiber.Handler
app.Get("/", adaptor.HTTPHandler(handler(greet)))
// Convert a http.Handler to a Fiber handler
app.Get("/", adaptor.HTTPHandler(http.HandlerFunc(helloHandler)))
// http.HandlerFunc -> fiber.Handler
app.Get("/func", adaptor.HTTPHandlerFunc(greet))
// Listen on port 3000
app.Listen(":3000")
}
func handler(f http.HandlerFunc) http.Handler {
return http.HandlerFunc(f)
}
func greet(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello World!")
func helloHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello from net/http!")
}
```
### net/http middleware to Fiber
### 2. Using `net/http` Middleware with Fiber
Middleware written for `net/http` can be used in Fiber:
```go
package main
@ -65,111 +67,119 @@ package main
import (
"log"
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
// New fiber app
app := fiber.New()
// http middleware -> fiber.Handler
app.Use(adaptor.HTTPMiddleware(logMiddleware))
// Apply a http middleware in Fiber
app.Use(adaptor.HTTPMiddleware(loggingMiddleware))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello Fiber!")
})
// Listen on port 3000
app.Listen(":3000")
}
func logMiddleware(next http.Handler) http.Handler {
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("log middleware")
log.Println("Request received")
next.ServeHTTP(w, r)
})
}
```
### Fiber Handler to net/http
### 3. Using Fiber Handlers in `net/http`
You can embed Fiber handlers inside `net/http`:
```go
package main
import (
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
// fiber.Handler -> http.Handler
http.Handle("/", adaptor.FiberHandler(greet))
// fiber.Handler -> http.HandlerFunc
http.HandleFunc("/func", adaptor.FiberHandlerFunc(greet))
// Listen on port 3000
// Convert Fiber handler to an http.Handler
http.Handle("/", adaptor.FiberHandler(helloFiber))
// Convert Fiber handler to http.HandlerFunc
http.HandleFunc("/func", adaptor.FiberHandlerFunc(helloFiber))
http.ListenAndServe(":3000", nil)
}
func greet(c fiber.Ctx) error {
return c.SendString("Hello World!")
func helloFiber(c fiber.Ctx) error {
return c.SendString("Hello from Fiber!")
}
```
### Fiber App to net/http
### 4. Running a Fiber App in `net/http`
You can wrap a full Fiber app inside `net/http`:
```go
package main
import (
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello from Fiber!")
})
app.Get("/greet", greet)
// Listen on port 3000
// Run Fiber inside an http server
http.ListenAndServe(":3000", adaptor.FiberApp(app))
}
func greet(c fiber.Ctx) error {
return c.SendString("Hello World!")
}
```
### Fiber Context to (net/http).Request
### 5. Converting Fiber Context (`fiber.Ctx`) to `http.Request`
If you need to use a `http.Request` inside a Fiber handler:
```go
package main
import (
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
app := fiber.New()
app.Get("/greet", greetWithHTTPReq)
// Listen on port 3000
http.ListenAndServe(":3000", adaptor.FiberApp(app))
app.Get("/request", handleRequest)
app.Listen(":3000")
}
func greetWithHTTPReq(c fiber.Ctx) error {
func handleRequest(c fiber.Ctx) error {
httpReq, err := adaptor.ConvertRequest(c, false)
if err != nil {
return err
}
return c.SendString("Request URL: " + httpReq.URL.String())
return c.SendString("Converted Request URL: " + httpReq.URL.String())
}
```
---
## Summary
The `adaptor` package allows easy interoperation between Fiber and `net/http`. You can:
- Convert handlers and middleware in both directions.
- Run Fiber apps inside `net/http`.
- Convert `fiber.Ctx` to `http.Request`.
This makes it simple to integrate Fiber with existing Go projects or migrate between frameworks as needed.

View File

@ -27,7 +27,7 @@ Liveness, readiness and startup probes middleware for [Fiber](https://github.com
## Signatures
```go
func NewHealthChecker(config Config) fiber.Handler
func New(config Config) fiber.Handler
```
## Examples
@ -41,38 +41,44 @@ import(
)
```
After you initiate your [Fiber](https://github.com/gofiber/fiber) app, you can use the following possibilities:
After you initiate your [Fiber](https://github.com/gofiber/fiber) app, you can use the following options:
```go
// Provide a minimal config for liveness check
app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker())
app.Get(healthcheck.LivenessEndpoint, healthcheck.New())
// Provide a minimal config for readiness check
app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker())
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New())
// Provide a minimal config for startup check
app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker())
app.Get(healthcheck.StartupEndpoint, healthcheck.New())
// Provide a minimal config for check with custom endpoint
app.Get("/live", healthcheck.NewHealthChecker())
app.Get("/live", healthcheck.New())
// Or extend your config for customization
app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return true
},
}))
// And it works the same for readiness, just change the route
app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return true
},
}))
// And it works the same for startup, just change the route
app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return true
},
}))
// With a custom route and custom probe
app.Get("/live", healthcheck.NewHealthChecker(healthcheck.Config{
app.Get("/live", healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return true
},
@ -81,7 +87,7 @@ app.Get("/live", healthcheck.NewHealthChecker(healthcheck.Config{
// It can also be used with app.All, although it will only respond to requests with the GET method
// in case of calling the route with any method which isn't GET, the return will be 404 Not Found when app.All is used
// and 405 Method Not Allowed when app.Get is used
app.All(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
app.All(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return true
},
@ -108,7 +114,7 @@ type Config struct {
// initialization and readiness checks
//
// Optional. Default: func(c fiber.Ctx) bool { return true }
Probe HealthChecker
Probe func(fiber.Ctx) bool
}
```
@ -117,7 +123,7 @@ type Config struct {
The default configuration used by this middleware is defined as follows:
```go
func defaultProbe(fiber.Ctx) bool { return true }
func defaultProbe(_ fiber.Ctx) bool { return true }
var ConfigDefault = Config{
Probe: defaultProbe,

View File

@ -54,7 +54,7 @@ curl -I http://localhost:3000
| ContentSecurityPolicy | `string` | ContentSecurityPolicy | "" |
| CSPReportOnly | `bool` | CSPReportOnly | false |
| HSTSPreloadEnabled | `bool` | HSTSPreloadEnabled | false |
| ReferrerPolicy | `string` | ReferrerPolicy | "ReferrerPolicy" |
| ReferrerPolicy | `string` | ReferrerPolicy | "no-referrer" |
| PermissionPolicy | `string` | Permissions-Policy | "" |
| CrossOriginEmbedderPolicy | `string` | Cross-Origin-Embedder-Policy | "require-corp" |
| CrossOriginOpenerPolicy | `string` | Cross-Origin-Opener-Policy | "same-origin" |

View File

@ -55,13 +55,13 @@ app.Use(logger.New(logger.Config{
}))
// Custom File Writer
file, err := os.OpenFile("./123.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
accessLog, err := os.OpenFile("./access.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
log.Fatalf("error opening file: %v", err)
log.Fatalf("error opening access.log file: %v", err)
}
defer file.Close()
defer accessLog.Close()
app.Use(logger.New(logger.Config{
Output: file,
Stream: accessLog,
}))
// Add Custom Tags
@ -79,7 +79,7 @@ app.Use(logger.New(logger.Config{
TimeZone: "Asia/Shanghai",
Done: func(c fiber.Ctx, logString []byte) {
if c.Response().StatusCode() != fiber.StatusOK {
reporter.SendToSlack(logString)
reporter.SendToSlack(logString)
}
},
}))
@ -88,6 +88,23 @@ app.Use(logger.New(logger.Config{
app.Use(logger.New(logger.Config{
DisableColors: true,
}))
// Use predefined formats
app.Use(logger.New(logger.Config{
Format: logger.FormatCommon,
}))
app.Use(logger.New(logger.Config{
Format: logger.FormatCombined,
}))
app.Use(logger.New(logger.Config{
Format: logger.FormatJSON,
}))
app.Use(logger.New(logger.Config{
Format: logger.FormatECS,
}))
```
### Use Logger Middleware with Other Loggers
@ -115,7 +132,7 @@ func main() {
// Use the logger middleware with zerolog logger
app.Use(logger.New(logger.Config{
Output: logger.LoggerToWriter(zap, log.LevelDebug),
Stream: logger.LoggerToWriter(zap, log.LevelDebug),
}))
// Define a route
@ -129,45 +146,57 @@ func main() {
```
:::tip
Writing to os.File is goroutine-safe, but if you are using a custom Output that is not goroutine-safe, make sure to implement locking to properly serialize writes.
Writing to os.File is goroutine-safe, but if you are using a custom Stream that is not goroutine-safe, make sure to implement locking to properly serialize writes.
:::
## Config
### Config
| Property | Type | Description | Default |
|:-----------------|:---------------------------|:---------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| Done | `func(fiber.Ctx, []byte)` | Done is a function that is called after the log string for a request is written to Output, and pass the log string as parameter. | `nil` |
| CustomTags | `map[string]LogFunc` | tagFunctions defines the custom tag action. | `map[string]LogFunc` |
| Format | `string` | Format defines the logging tags. | `[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n` |
| TimeFormat | `string` | TimeFormat defines the time format for log timestamps. | `15:04:05` |
| TimeZone | `string` | TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc | `"Local"` |
| TimeInterval | `time.Duration` | TimeInterval is the delay before the timestamp is updated. | `500 * time.Millisecond` |
| Output | `io.Writer` | Output is a writer where logs are written. | `os.Stdout` |
| LoggerFunc | `func(c fiber.Ctx, data *Data, cfg Config) error` | Custom logger function for integration with logging libraries (Zerolog, Zap, Logrus, etc). Defaults to Fiber's default logger if not defined. | `see default_logger.go defaultLoggerInstance` |
| DisableColors | `bool` | DisableColors defines if the logs output should be colorized. | `false` |
| enableColors | `bool` | Internal field for enabling colors in the log output. (This is not a user-configurable field) | - |
| enableLatency | `bool` | Internal field for enabling latency measurement in logs. (This is not a user-configurable field) | - |
| timeZoneLocation | `*time.Location` | Internal field for the time zone location. (This is not a user-configurable field) | - |
| Property | Type | Description | Default |
| :------------ | :------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------- |
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| Skip | `func(fiber.Ctx) bool` | Skip is a function to determine if logging is skipped or written to Stream. | `nil` |
| Done | `func(fiber.Ctx, []byte)` | Done is a function that is called after the log string for a request is written to Stream, and pass the log string as parameter. | `nil` |
| CustomTags | `map[string]LogFunc` | tagFunctions defines the custom tag action. | `map[string]LogFunc` |
| `Format` | `string` | Defines the logging tags. See more in [Predefined Formats](#predefined-formats), or create your own using [Tags](#constants). | `[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n` (same as `DefaultFormat`) |
| TimeFormat | `string` | TimeFormat defines the time format for log timestamps. | `15:04:05` |
| TimeZone | `string` | TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc | `"Local"` |
| TimeInterval | `time.Duration` | TimeInterval is the delay before the timestamp is updated. | `500 * time.Millisecond` |
| Stream | `io.Writer` | Stream is a writer where logs are written. | `os.Stdout` |
| LoggerFunc | `func(c fiber.Ctx, data *Data, cfg Config) error` | Custom logger function for integration with logging libraries (Zerolog, Zap, Logrus, etc). Defaults to Fiber's default logger if not defined. | `see default_logger.go defaultLoggerInstance` |
| DisableColors | `bool` | DisableColors defines if the logs output should be colorized. | `false` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
Done: nil,
Format: "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n",
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Output: os.Stdout,
DisableColors: false,
LoggerFunc: defaultLoggerInstance,
Next: nil,
Skip: nil,
Done: nil,
Format: DefaultFormat,
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Stream: os.Stdout,
BeforeHandlerFunc: beforeHandlerFunc,
LoggerFunc: defaultLoggerInstance,
enableColors: true,
}
```
## Predefined Formats
Logger provides predefined formats that you can use by name or directly by specifying the format string.
| **Format Constant** | **Format String** | **Description** |
|---------------------|--------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| `DefaultFormat` | `"[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n"` | Fiber's default logger format. |
| `CommonFormat` | `"${ip} - - [${time}] "${method} ${url} ${protocol}" ${status} ${bytesSent}\n"` | Common Log Format (CLF) used in web server logs. |
| `CombinedFormat` | `"${ip} - - [${time}] "${method} ${url} ${protocol}" ${status} ${bytesSent} "${referer}" "${ua}"\n"` | CLF format plus the `referer` and `user agent` fields. |
| `JSONFormat` | `"{time: ${time}, ip: ${ip}, method: ${method}, url: ${url}, status: ${status}, bytesSent: ${bytesSent}}\n"` | JSON format for structured logging. |
| `ECSFormat` | `"{\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}\n"` | Elastic Common Schema (ECS) format for structured logging. |
## Constants
```go

View File

@ -9,22 +9,22 @@ Registers a route bound to a specific [HTTP method](https://developer.mozilla.or
```go title="Signatures"
// HTTP methods
func (app *App) Get(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Head(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Post(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Put(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Delete(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Connect(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Options(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Trace(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Patch(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Get(path string, handler Handler, handlers ...Handler) Router
func (app *App) Head(path string, handler Handler, handlers ...Handler) Router
func (app *App) Post(path string, handler Handler, handlers ...Handler) Router
func (app *App) Put(path string, handler Handler, handlers ...Handler) Router
func (app *App) Delete(path string, handler Handler, handlers ...Handler) Router
func (app *App) Connect(path string, handler Handler, handlers ...Handler) Router
func (app *App) Options(path string, handler Handler, handlers ...Handler) Router
func (app *App) Trace(path string, handler Handler, handlers ...Handler) Router
func (app *App) Patch(path string, handler Handler, handlers ...Handler) Router
// Add allows you to specify a method as value
func (app *App) Add(method, path string, handler Handler, middlewares ...Handler) Router
func (app *App) Add(method, path string, handler Handler, handlers ...Handler) Router
// All will register the route on all HTTP methods
// Almost the same as app.Use but not bound to prefixes
func (app *App) All(path string, handler Handler, middlewares ...Handler) Router
func (app *App) All(path string, handler Handler, handlers ...Handler) Router
```
```go title="Examples"
@ -47,9 +47,9 @@ Can be used for middleware packages and prefix catchers. These routes will only
func (app *App) Use(args ...any) Router
// Different usage variations
func (app *App) Use(handler Handler, middlewares ...Handler) Router
func (app *App) Use(path string, handler Handler, middlewares ...Handler) Router
func (app *App) Use(paths []string, handler Handler, middlewares ...Handler) Router
func (app *App) Use(handler Handler, handlers ...Handler) Router
func (app *App) Use(path string, handler Handler, handlers ...Handler) Router
func (app *App) Use(paths []string, handler Handler, handlers ...Handler) Router
func (app *App) Use(path string, app *App) Router
```

View File

@ -16,6 +16,8 @@ In this guide, we'll walk you through the most important changes in Fiber `v3` a
Here's a quick overview of the changes in Fiber `v3`:
- [🚀 App](#-app)
- [🎣 Hooks](#-hooks)
- [🚀 Listen](#-listen)
- [🗺️ Router](#-router)
- [🧠 Context](#-context)
- [📎 Binding](#-binding)
@ -31,6 +33,7 @@ Here's a quick overview of the changes in Fiber `v3`:
- [Filesystem](#filesystem)
- [Monitor](#monitor)
- [Healthcheck](#healthcheck)
- [🔌 Addons](#-addons)
- [📋 Migration guide](#-migration-guide)
## Drop for old Go versions
@ -56,6 +59,7 @@ We have made several changes to the Fiber app, including:
- **RegisterCustomBinder**: Allows for the registration of custom binders.
- **RegisterCustomConstraint**: Allows for the registration of custom constraints.
- **NewCtxFunc**: Introduces a new context function.
- **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details.
### Removed Methods
@ -158,6 +162,63 @@ app.Listen(":444", fiber.ListenConfig{
})
```
## 🎣 Hooks
We have made several changes to the Fiber hooks, including:
- Added new shutdown hooks to provide better control over the shutdown process:
- `OnPreShutdown` - Executes before the server starts shutting down
- `OnPostShutdown` - Executes after the server has shut down, receives any shutdown error
- Deprecated `OnShutdown` in favor of the new pre/post shutdown hooks
- Improved shutdown hook execution order and reliability
- Added mutex protection for hook registration and execution
Important: When using shutdown hooks, ensure app.Listen() is called in a separate goroutine:
```go
// Correct usage
go app.Listen(":3000")
// ... register shutdown hooks
app.Shutdown()
// Incorrect usage - hooks won't work
app.Listen(":3000") // This blocks
app.Shutdown() // Never reached
```
## 🚀 Listen
We have made several changes to the Fiber listen, including:
- Removed `OnShutdownError` and `OnShutdownSuccess` from `ListenerConfig` in favor of using `OnPostShutdown` hook which receives the shutdown error
```go
app := fiber.New()
// Before - using ListenerConfig callbacks
app.Listen(":3000", fiber.ListenerConfig{
OnShutdownError: func(err error) {
log.Printf("Shutdown error: %v", err)
},
OnShutdownSuccess: func() {
log.Println("Shutdown successful")
},
})
// After - using OnPostShutdown hook
app.Hooks().OnPostShutdown(func(err error) error {
if err != nil {
log.Printf("Shutdown error: %v", err)
} else {
log.Println("Shutdown successful")
}
return nil
})
go app.Listen(":3000")
```
This change simplifies the shutdown handling by consolidating the shutdown callbacks into a single hook that receives the error status.
## 🗺 Router
We have slightly adapted our router interface
@ -487,6 +548,7 @@ Fiber v3 introduces a new binding mechanism that simplifies the process of bindi
- Unified binding from URL parameters, query parameters, headers, and request bodies.
- Support for custom binders and constraints.
- Improved error handling and validation.
- Support multipart file binding for `*multipart.FileHeader`, `*[]*multipart.FileHeader`, and `[]*multipart.FileHeader` field types.
<details>
<summary>Example</summary>
@ -851,6 +913,47 @@ func main() {
</details>
The `Skip` is a function to determine if logging is skipped or written to `Stream`.
<details>
<summary>Example Usage</summary>
```go
app.Use(logger.New(logger.Config{
Skip: func(c fiber.Ctx) bool {
// Skip logging HTTP 200 requests
return c.Response().StatusCode() == fiber.StatusOK
},
}))
```
```go
app.Use(logger.New(logger.Config{
Skip: func(c fiber.Ctx) bool {
// Only log errors, similar to an error.log
return c.Response().StatusCode() < 400
},
}))
```
</details>
#### Predefined Formats
Logger provides predefined formats that you can use by name or directly by specifying the format string.
<details>
<summary>Example Usage</summary>
```go
app.Use(logger.New(logger.Config{
Format: logger.FormatCombined,
}))
```
See more in [Logger](./middleware/logger.md#predefined-formats)
</details>
### Filesystem
We've decided to remove filesystem middleware to clear up the confusion between static and filesystem middleware.
@ -879,6 +982,59 @@ The Healthcheck middleware has been enhanced to support more than two routes, wi
Refer to the [healthcheck middleware migration guide](./middleware/healthcheck.md) or the [general migration guide](#-migration-guide) to review the changes.
## 🔌 Addons
In v3, Fiber introduced Addons. Addons are additional useful packages that can be used in Fiber.
### Retry
The Retry addon is a new addon that implements a retry mechanism for unsuccessful network operations. It uses an exponential backoff algorithm with jitter.
It calls the function multiple times and tries to make it successful. If all calls are failed, then, it returns an error.
It adds a jitter at each retry step because adding a jitter is a way to break synchronization across the client and avoid collision.
<details>
<summary>Example</summary>
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3/addon/retry"
"github.com/gofiber/fiber/v3/client"
)
func main() {
expBackoff := retry.NewExponentialBackoff(retry.Config{})
// Local variables that will be used inside of Retry
var resp *client.Response
var err error
// Retry a network request and return an error to signify to try again
err = expBackoff.Retry(func() error {
client := client.New()
resp, err = client.Get("https://gofiber.io")
if err != nil {
return fmt.Errorf("GET gofiber.io failed: %w", err)
}
if resp.StatusCode() != 200 {
return fmt.Errorf("GET gofiber.io did not return OK 200")
}
return nil
})
// If all retries failed, panic
if err != nil {
panic(err)
}
fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode())
}
```
</details>
## 📋 Migration guide
- [🚀 App](#-app-1)
@ -1409,25 +1565,25 @@ With the new version, each health check endpoint is configured separately, allow
// after
// Default liveness endpoint configuration
app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return true
},
}))
// Default readiness endpoint configuration
app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker())
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New())
// New default startup endpoint configuration
// Default endpoint is /startupz
app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return serviceA.Ready() && serviceB.Ready() && ...
},
}))
// Custom liveness endpoint configuration
app.Get("/live", healthcheck.NewHealthChecker())
app.Get("/live", healthcheck.New())
```
#### Monitor

View File

@ -40,7 +40,7 @@ var (
ErrNoHandlers = errors.New("format: at least one handler is required, but none were set")
)
// gorilla/schema errors
// gofiber/schema errors
type (
// ConversionError Conversion error exposes the internal schema.ConversionError for public use.
ConversionError = schema.ConversionError

19
go.mod
View File

@ -1,31 +1,30 @@
module github.com/gofiber/fiber/v3
go 1.23
go 1.23.0
require (
github.com/gofiber/schema v1.2.0
github.com/gofiber/utils/v2 v2.0.0-beta.7
github.com/gofiber/schema v1.3.0
github.com/gofiber/utils/v2 v2.0.0-beta.8
github.com/google/uuid v1.6.0
github.com/mattn/go-colorable v0.1.14
github.com/mattn/go-isatty v0.0.20
github.com/stretchr/testify v1.10.0
github.com/tinylib/msgp v1.2.5
github.com/valyala/bytebufferpool v1.0.0
github.com/valyala/fasthttp v1.58.0
golang.org/x/crypto v0.33.0
github.com/valyala/fasthttp v1.59.0
golang.org/x/crypto v0.36.0
)
require (
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // direct
github.com/fxamacker/cbor/v2 v2.8.0 // direct
github.com/klauspost/compress v1.17.11 // indirect
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

34
go.sum
View File

@ -2,12 +2,12 @@ github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7X
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/gofiber/schema v1.2.0 h1:j+ZRrNnUa/0ZuWrn/6kAtAufEr4jCJ+JuTURAMxNSZg=
github.com/gofiber/schema v1.2.0/go.mod h1:YYwj01w3hVfaNjhtJzaqetymL56VW642YS3qZPhuE6c=
github.com/gofiber/utils/v2 v2.0.0-beta.7 h1:NnHFrRHvhrufPABdWajcKZejz9HnCWmT/asoxRsiEbQ=
github.com/gofiber/utils/v2 v2.0.0-beta.7/go.mod h1:J/M03s+HMdZdvhAeyh76xT72IfVqBzuz/OJkrMa7cwU=
github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU=
github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gofiber/schema v1.3.0 h1:K3F3wYzAY+aivfCCEHPufCthu5/13r/lzp1nuk6mr3Q=
github.com/gofiber/schema v1.3.0/go.mod h1:YYwj01w3hVfaNjhtJzaqetymL56VW642YS3qZPhuE6c=
github.com/gofiber/utils/v2 v2.0.0-beta.8 h1:ZifwbHZqZO3YJsx1ZhDsWnPjaQ7C0YD20LHt+DQeXOU=
github.com/gofiber/utils/v2 v2.0.0-beta.8/go.mod h1:1lCBo9vEF4RFEtTgWntipnaScJZQiM8rrsYycLZ4n9c=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
@ -26,23 +26,21 @@ github.com/tinylib/msgp v1.2.5 h1:WeQg1whrXRFiZusidTQqzETkRpGjFjcIhW6uqWH09po=
github.com/tinylib/msgp v1.2.5/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.58.0 h1:GGB2dWxSbEprU9j0iMJHgdKYJVDyjrOwF9RE59PbRuE=
github.com/valyala/fasthttp v1.58.0/go.mod h1:SYXvHHaFp7QZHGKSHmoMipInhrI5StHrhDTYVEjK/Kw=
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDpRI=
github.com/valyala/fasthttp v1.59.0/go.mod h1:GTxNb9Bc6r2a9D0TWNSPwDz78UxnTGBViY3xZNEqyYU=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@ -97,7 +97,7 @@ func (grp *Group) Use(args ...any) Router {
return grp
}
grp.app.register([]string{methodUse}, getGroupPath(grp.Prefix, prefix), grp, nil, handlers...)
grp.app.register([]string{methodUse}, getGroupPath(grp.Prefix, prefix), grp, handlers...)
}
if !grp.anyRouteDefined {
@ -109,60 +109,60 @@ func (grp *Group) Use(args ...any) Router {
// Get registers a route for GET methods that requests a representation
// of the specified resource. Requests using GET should only retrieve data.
func (grp *Group) Get(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodGet}, path, handler, middleware...)
func (grp *Group) Get(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodGet}, path, handler, handlers...)
}
// Head registers a route for HEAD methods that asks for a response identical
// to that of a GET request, but without the response body.
func (grp *Group) Head(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodHead}, path, handler, middleware...)
func (grp *Group) Head(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodHead}, path, handler, handlers...)
}
// Post registers a route for POST methods that is used to submit an entity to the
// specified resource, often causing a change in state or side effects on the server.
func (grp *Group) Post(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodPost}, path, handler, middleware...)
func (grp *Group) Post(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodPost}, path, handler, handlers...)
}
// Put registers a route for PUT methods that replaces all current representations
// of the target resource with the request payload.
func (grp *Group) Put(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodPut}, path, handler, middleware...)
func (grp *Group) Put(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodPut}, path, handler, handlers...)
}
// Delete registers a route for DELETE methods that deletes the specified resource.
func (grp *Group) Delete(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodDelete}, path, handler, middleware...)
func (grp *Group) Delete(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodDelete}, path, handler, handlers...)
}
// Connect registers a route for CONNECT methods that establishes a tunnel to the
// server identified by the target resource.
func (grp *Group) Connect(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodConnect}, path, handler, middleware...)
func (grp *Group) Connect(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodConnect}, path, handler, handlers...)
}
// Options registers a route for OPTIONS methods that is used to describe the
// communication options for the target resource.
func (grp *Group) Options(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodOptions}, path, handler, middleware...)
func (grp *Group) Options(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodOptions}, path, handler, handlers...)
}
// Trace registers a route for TRACE methods that performs a message loop-back
// test along the path to the target resource.
func (grp *Group) Trace(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodTrace}, path, handler, middleware...)
func (grp *Group) Trace(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodTrace}, path, handler, handlers...)
}
// Patch registers a route for PATCH methods that is used to apply partial
// modifications to a resource.
func (grp *Group) Patch(path string, handler Handler, middleware ...Handler) Router {
return grp.Add([]string{MethodPatch}, path, handler, middleware...)
func (grp *Group) Patch(path string, handler Handler, handlers ...Handler) Router {
return grp.Add([]string{MethodPatch}, path, handler, handlers...)
}
// Add allows you to specify multiple HTTP methods to register a route.
func (grp *Group) Add(methods []string, path string, handler Handler, middleware ...Handler) Router {
grp.app.register(methods, getGroupPath(grp.Prefix, path), grp, handler, middleware...)
func (grp *Group) Add(methods []string, path string, handler Handler, handlers ...Handler) Router {
grp.app.register(methods, getGroupPath(grp.Prefix, path), grp, append([]Handler{handler}, handlers...)...)
if !grp.anyRouteDefined {
grp.anyRouteDefined = true
}
@ -171,8 +171,8 @@ func (grp *Group) Add(methods []string, path string, handler Handler, middleware
}
// All will register the handler on all HTTP methods
func (grp *Group) All(path string, handler Handler, middleware ...Handler) Router {
_ = grp.Add(grp.app.config.RequestMethods, path, handler, middleware...)
func (grp *Group) All(path string, handler Handler, handlers ...Handler) Router {
_ = grp.Add(grp.app.config.RequestMethods, path, handler, handlers...)
return grp
}
@ -183,7 +183,7 @@ func (grp *Group) All(path string, handler Handler, middleware ...Handler) Route
func (grp *Group) Group(prefix string, handlers ...Handler) Router {
prefix = getGroupPath(grp.Prefix, prefix)
if len(handlers) > 0 {
grp.app.register([]string{methodUse}, prefix, grp, nil, handlers...)
grp.app.register([]string{methodUse}, prefix, grp, handlers...)
}
// Create new group

View File

@ -14,6 +14,7 @@ import (
"os"
"path/filepath"
"reflect"
"slices"
"strconv"
"strings"
"sync"
@ -51,16 +52,16 @@ func getTLSConfig(ln net.Listener) *tls.Config {
}
// Copy value from pointer
if val := reflect.Indirect(pointer); val.Type() != nil {
if val := reflect.Indirect(pointer); val.IsValid() {
// Get private field from value
if field := val.FieldByName("config"); field.Type() != nil {
if field := val.FieldByName("config"); field.IsValid() {
// Copy value from pointer field (unsafe)
newval := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())) //nolint:gosec // Probably the only way to extract the *tls.Config from a net.Listener. TODO: Verify there really is no easier way without using unsafe.
if newval.Type() == nil {
if !newval.IsValid() {
return nil
}
// Get element from pointer
if elem := newval.Elem(); elem.Type() != nil {
if elem := newval.Elem(); elem.IsValid() {
// Cast value to *tls.Config
c, ok := elem.Interface().(*tls.Config)
if !ok {
@ -107,15 +108,15 @@ func (app *App) methodExist(c *DefaultCtx) bool {
methods := app.config.RequestMethods
for i := 0; i < len(methods); i++ {
// Skip original method
if c.getMethodINT() == i {
if c.getMethodInt() == i {
continue
}
// Reset stack index
c.setIndexRoute(-1)
tree, ok := c.App().treeStack[i][c.getTreePath()]
tree, ok := c.App().treeStack[i][c.treePathHash]
if !ok {
tree = c.App().treeStack[i][""]
tree = c.App().treeStack[i][0]
}
// Get stack length
lenr := len(tree) - 1
@ -151,15 +152,15 @@ func (app *App) methodExistCustom(c CustomCtx) bool {
methods := app.config.RequestMethods
for i := 0; i < len(methods); i++ {
// Skip original method
if c.getMethodINT() == i {
if c.getMethodInt() == i {
continue
}
// Reset stack index
c.setIndexRoute(-1)
tree, ok := c.App().treeStack[i][c.getTreePath()]
tree, ok := c.App().treeStack[i][c.getTreePathHash()]
if !ok {
tree = c.App().treeStack[i][""]
tree = c.App().treeStack[i][0]
}
// Get stack length
lenr := len(tree) - 1
@ -192,12 +193,10 @@ func (app *App) methodExistCustom(c CustomCtx) bool {
// uniqueRouteStack drop all not unique routes from the slice
func uniqueRouteStack(stack []*Route) []*Route {
var unique []*Route
m := make(map[*Route]int)
m := make(map[*Route]struct{})
for _, v := range stack {
if _, ok := m[v]; !ok {
// Unique key found. Record position and collect
// in result.
m[v] = len(unique)
m[v] = struct{}{}
unique = append(unique, v)
}
}
@ -323,28 +322,23 @@ func getSplicedStrList(headerValue string, dst []string) []string {
return nil
}
var (
index int
character rune
lastElementEndsAt int
insertIndex int
)
for index, character = range headerValue + "$" {
if character == ',' || index == len(headerValue) {
if insertIndex >= len(dst) {
oldSlice := dst
dst = make([]string, len(dst)+(len(dst)>>1)+2)
copy(dst, oldSlice)
}
dst[insertIndex] = utils.TrimLeft(headerValue[lastElementEndsAt:index], ' ')
lastElementEndsAt = index + 1
insertIndex++
dst = dst[:0]
segmentStart := 0
isLeadingSpace := true
for i, c := range headerValue {
switch {
case c == ',':
dst = append(dst, headerValue[segmentStart:i])
segmentStart = i + 1
isLeadingSpace = true
case c == ' ' && isLeadingSpace:
segmentStart = i + 1
default:
isLeadingSpace = false
}
}
dst = append(dst, headerValue[segmentStart:])
if len(dst) > insertIndex {
dst = dst[:insertIndex]
}
return dst
}
@ -490,7 +484,7 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head
if len(acceptedTypes) > 1 {
// Sort accepted types by quality and specificity, preserving order of equal elements
sortAcceptedTypes(&acceptedTypes)
sortAcceptedTypes(acceptedTypes)
}
// Find the first offer that matches the accepted types
@ -518,19 +512,14 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head
// A type with parameters has higher priority than an equivalent one without parameters.
// e.g., text/html;a=1;b=2 comes before text/html;a=1
// See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields
func sortAcceptedTypes(acceptedTypes *[]acceptedType) {
if acceptedTypes == nil || len(*acceptedTypes) < 2 {
return
}
at := *acceptedTypes
func sortAcceptedTypes(at []acceptedType) {
for i := 1; i < len(at); i++ {
lo, hi := 0, i-1
for lo <= hi {
mid := (lo + hi) / 2
if at[i].quality < at[mid].quality ||
(at[i].quality == at[mid].quality && at[i].specificity < at[mid].specificity) ||
(at[i].quality == at[mid].quality && at[i].specificity < at[mid].specificity && len(at[i].params) < len(at[mid].params)) ||
(at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && len(at[i].params) < len(at[mid].params)) ||
(at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && len(at[i].params) == len(at[mid].params) && at[i].order > at[mid].order) {
lo = mid + 1
} else {
@ -664,39 +653,35 @@ func getBytesImmutable(s string) []byte {
func (app *App) methodInt(s string) int {
// For better performance
if len(app.configured.RequestMethods) == 0 {
// TODO: Use iota instead
switch s {
case MethodGet:
return 0
return methodGet
case MethodHead:
return 1
return methodHead
case MethodPost:
return 2
return methodPost
case MethodPut:
return 3
return methodPut
case MethodDelete:
return 4
return methodDelete
case MethodConnect:
return 5
return methodConnect
case MethodOptions:
return 6
return methodOptions
case MethodTrace:
return 7
return methodTrace
case MethodPatch:
return 8
return methodPatch
default:
return -1
}
}
// For method customization
for i, v := range app.config.RequestMethods {
if s == v {
return i
}
}
return slices.Index(app.config.RequestMethods, s)
}
return -1
func (app *App) method(methodInt int) string {
return app.config.RequestMethods[methodInt]
}
// IsMethodSafe reports whether the HTTP method is considered safe.
@ -796,7 +781,7 @@ func genericParseType[V GenericType](str string, v V, defaultValue ...V) V {
case int64:
return genericParseInt[V](str, 64, func(i int64) V { return assertValueType[V, int64](i) }, defaultValue...)
case uint:
return genericParseUint[V](str, 32, func(i uint64) V { return assertValueType[V, uint](uint(i)) }, defaultValue...)
return genericParseUint[V](str, 0, func(i uint64) V { return assertValueType[V, uint](uint(i)) }, defaultValue...)
case uint8:
return genericParseUint[V](str, 8, func(i uint64) V { return assertValueType[V, uint8](uint8(i)) }, defaultValue...)
case uint16:

View File

@ -303,6 +303,26 @@ func Test_Utils_GetSplicedStrList(t *testing.T) {
headerValue: "gzip,",
expectedList: []string{"gzip", ""},
},
{
description: "has a space between words",
headerValue: " foo bar, hello world",
expectedList: []string{"foo bar", "hello world"},
},
{
description: "single comma",
headerValue: ",",
expectedList: []string{"", ""},
},
{
description: "multiple comma",
headerValue: ",,",
expectedList: []string{"", "", ""},
},
{
description: "comma with space",
headerValue: ", ,",
expectedList: []string{"", "", ""},
},
}
for _, tc := range testCases {
@ -334,7 +354,6 @@ func Test_Utils_SortAcceptedTypes(t *testing.T) {
{spec: "text/html", quality: 1, specificity: 3, order: 0},
{spec: "text/*", quality: 0.5, specificity: 2, order: 1},
{spec: "*/*", quality: 0.1, specificity: 1, order: 2},
{spec: "application/json", quality: 0.999, specificity: 3, order: 3},
{spec: "application/xml", quality: 1, specificity: 3, order: 4},
{spec: "application/pdf", quality: 1, specificity: 3, order: 5},
{spec: "image/png", quality: 1, specificity: 3, order: 6},
@ -343,8 +362,9 @@ func Test_Utils_SortAcceptedTypes(t *testing.T) {
{spec: "image/gif", quality: 1, specificity: 3, order: 9},
{spec: "text/plain", quality: 1, specificity: 3, order: 10},
{spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11},
{spec: "application/json", quality: 0.999, specificity: 3, order: 3},
}
sortAcceptedTypes(&acceptedTypes)
sortAcceptedTypes(acceptedTypes)
require.Equal(t, []acceptedType{
{spec: "text/html", quality: 1, specificity: 3, order: 0},
{spec: "application/xml", quality: 1, specificity: 3, order: 4},
@ -370,7 +390,7 @@ func Benchmark_Utils_SortAcceptedTypes_Sorted(b *testing.B) {
acceptedTypes[0] = acceptedType{spec: "text/html", quality: 1, specificity: 1, order: 0}
acceptedTypes[1] = acceptedType{spec: "text/*", quality: 0.5, specificity: 1, order: 1}
acceptedTypes[2] = acceptedType{spec: "*/*", quality: 0.1, specificity: 1, order: 2}
sortAcceptedTypes(&acceptedTypes)
sortAcceptedTypes(acceptedTypes)
}
require.Equal(b, "text/html", acceptedTypes[0].spec)
require.Equal(b, "text/*", acceptedTypes[1].spec)
@ -394,7 +414,7 @@ func Benchmark_Utils_SortAcceptedTypes_Unsorted(b *testing.B) {
acceptedTypes[8] = acceptedType{spec: "image/*", quality: 1, specificity: 2, order: 8}
acceptedTypes[9] = acceptedType{spec: "image/gif", quality: 1, specificity: 3, order: 9}
acceptedTypes[10] = acceptedType{spec: "text/plain", quality: 1, specificity: 3, order: 10}
sortAcceptedTypes(&acceptedTypes)
sortAcceptedTypes(acceptedTypes)
}
require.Equal(b, []acceptedType{
{spec: "text/html", quality: 1, specificity: 3, order: 0},
@ -546,7 +566,7 @@ func Test_Utils_TestConn_Closed_Write(t *testing.T) {
require.NoError(t, err)
// Close early, write should fail
conn.Close() //nolint:errcheck, revive // It is fine to ignore the error here
conn.Close() //nolint:errcheck // It is fine to ignore the error here
_, err = conn.Write([]byte("Response 2\n"))
require.ErrorIs(t, err, errTestConnClosed)

View File

@ -6,14 +6,15 @@ import (
// OnRouteHandler Handlers define a function to create hooks for Fiber.
type (
OnRouteHandler = func(Route) error
OnNameHandler = OnRouteHandler
OnGroupHandler = func(Group) error
OnGroupNameHandler = OnGroupHandler
OnListenHandler = func(ListenData) error
OnShutdownHandler = func() error
OnForkHandler = func(int) error
OnMountHandler = func(*App) error
OnRouteHandler = func(Route) error
OnNameHandler = OnRouteHandler
OnGroupHandler = func(Group) error
OnGroupNameHandler = OnGroupHandler
OnListenHandler = func(ListenData) error
OnPreShutdownHandler = func() error
OnPostShutdownHandler = func(error) error
OnForkHandler = func(int) error
OnMountHandler = func(*App) error
)
// Hooks is a struct to use it with App.
@ -22,14 +23,15 @@ type Hooks struct {
app *App
// Hooks
onRoute []OnRouteHandler
onName []OnNameHandler
onGroup []OnGroupHandler
onGroupName []OnGroupNameHandler
onListen []OnListenHandler
onShutdown []OnShutdownHandler
onFork []OnForkHandler
onMount []OnMountHandler
onRoute []OnRouteHandler
onName []OnNameHandler
onGroup []OnGroupHandler
onGroupName []OnGroupNameHandler
onListen []OnListenHandler
onPreShutdown []OnPreShutdownHandler
onPostShutdown []OnPostShutdownHandler
onFork []OnForkHandler
onMount []OnMountHandler
}
// ListenData is a struct to use it with OnListenHandler
@ -41,15 +43,16 @@ type ListenData struct {
func newHooks(app *App) *Hooks {
return &Hooks{
app: app,
onRoute: make([]OnRouteHandler, 0),
onGroup: make([]OnGroupHandler, 0),
onGroupName: make([]OnGroupNameHandler, 0),
onName: make([]OnNameHandler, 0),
onListen: make([]OnListenHandler, 0),
onShutdown: make([]OnShutdownHandler, 0),
onFork: make([]OnForkHandler, 0),
onMount: make([]OnMountHandler, 0),
app: app,
onRoute: make([]OnRouteHandler, 0),
onGroup: make([]OnGroupHandler, 0),
onGroupName: make([]OnGroupNameHandler, 0),
onName: make([]OnNameHandler, 0),
onListen: make([]OnListenHandler, 0),
onPreShutdown: make([]OnPreShutdownHandler, 0),
onPostShutdown: make([]OnPostShutdownHandler, 0),
onFork: make([]OnForkHandler, 0),
onMount: make([]OnMountHandler, 0),
}
}
@ -96,10 +99,17 @@ func (h *Hooks) OnListen(handler ...OnListenHandler) {
h.app.mutex.Unlock()
}
// OnShutdown is a hook to execute user functions after Shutdown.
func (h *Hooks) OnShutdown(handler ...OnShutdownHandler) {
// OnPreShutdown is a hook to execute user functions before Shutdown.
func (h *Hooks) OnPreShutdown(handler ...OnPreShutdownHandler) {
h.app.mutex.Lock()
h.onShutdown = append(h.onShutdown, handler...)
h.onPreShutdown = append(h.onPreShutdown, handler...)
h.app.mutex.Unlock()
}
// OnPostShutdown is a hook to execute user functions after Shutdown.
func (h *Hooks) OnPostShutdown(handler ...OnPostShutdownHandler) {
h.app.mutex.Lock()
h.onPostShutdown = append(h.onPostShutdown, handler...)
h.app.mutex.Unlock()
}
@ -191,10 +201,18 @@ func (h *Hooks) executeOnListenHooks(listenData ListenData) error {
return nil
}
func (h *Hooks) executeOnShutdownHooks() {
for _, v := range h.onShutdown {
func (h *Hooks) executeOnPreShutdownHooks() {
for _, v := range h.onPreShutdown {
if err := v(); err != nil {
log.Errorf("failed to call shutdown hook: %v", err)
log.Errorf("failed to call pre shutdown hook: %v", err)
}
}
}
func (h *Hooks) executeOnPostShutdownHooks(err error) {
for _, v := range h.onPostShutdown {
if err := v(err); err != nil {
log.Errorf("failed to call post shutdown hook: %v", err)
}
}
}

View File

@ -2,7 +2,6 @@ package fiber
import (
"errors"
"fmt"
"testing"
"time"
@ -83,17 +82,14 @@ func Test_Hook_OnName(t *testing.T) {
func Test_Hook_OnName_Error(t *testing.T) {
t.Parallel()
app := New()
defer func() {
if err := recover(); err != nil {
require.Equal(t, "unknown error", fmt.Sprintf("%v", err))
}
}()
app.Hooks().OnName(func(_ Route) error {
return errors.New("unknown error")
})
app.Get("/", testSimpleHandler).Name("index")
require.PanicsWithError(t, "unknown error", func() {
app.Get("/", testSimpleHandler).Name("index")
})
}
func Test_Hook_OnGroup(t *testing.T) {
@ -167,36 +163,99 @@ func Test_Hook_OnGroupName(t *testing.T) {
func Test_Hook_OnGroupName_Error(t *testing.T) {
t.Parallel()
app := New()
defer func() {
if err := recover(); err != nil {
require.Equal(t, "unknown error", fmt.Sprintf("%v", err))
}
}()
app.Hooks().OnGroupName(func(_ Group) error {
return errors.New("unknown error")
})
grp := app.Group("/x").Name("x.")
grp.Get("/test", testSimpleHandler)
require.PanicsWithError(t, "unknown error", func() {
_ = app.Group("/x").Name("x.")
})
}
func Test_Hook_OnShutdown(t *testing.T) {
func Test_Hook_OnPrehutdown(t *testing.T) {
t.Parallel()
app := New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Hooks().OnShutdown(func() error {
_, err := buf.WriteString("shutdowning")
app.Hooks().OnPreShutdown(func() error {
_, err := buf.WriteString("pre-shutdowning")
require.NoError(t, err)
return nil
})
require.NoError(t, app.Shutdown())
require.Equal(t, "shutdowning", buf.String())
require.Equal(t, "pre-shutdowning", buf.String())
}
func Test_Hook_OnPostShutdown(t *testing.T) {
t.Run("should execute post shutdown hook with error", func(t *testing.T) {
app := New()
expectedErr := errors.New("test shutdown error")
hookCalled := make(chan error, 1)
defer close(hookCalled)
app.Hooks().OnPostShutdown(func(err error) error {
hookCalled <- err
return nil
})
go func() {
if err := app.Listen(":0"); err != nil {
return
}
}()
time.Sleep(100 * time.Millisecond)
app.hooks.executeOnPostShutdownHooks(expectedErr)
select {
case err := <-hookCalled:
require.Equal(t, expectedErr, err)
case <-time.After(time.Second):
t.Fatal("hook execution timeout")
}
require.NoError(t, app.Shutdown())
})
t.Run("should execute multiple hooks in order", func(t *testing.T) {
app := New()
execution := make([]int, 0)
app.Hooks().OnPostShutdown(func(_ error) error {
execution = append(execution, 1)
return nil
})
app.Hooks().OnPostShutdown(func(_ error) error {
execution = append(execution, 2)
return nil
})
app.hooks.executeOnPostShutdownHooks(nil)
require.Len(t, execution, 2, "expected 2 hooks to execute")
require.Equal(t, []int{1, 2}, execution, "hooks executed in wrong order")
})
t.Run("should handle hook error", func(_ *testing.T) {
app := New()
hookErr := errors.New("hook error")
app.Hooks().OnPostShutdown(func(_ error) error {
return hookErr
})
// Should not panic
app.hooks.executeOnPostShutdownHooks(nil)
})
}
func Test_Hook_OnListen(t *testing.T) {

View File

@ -209,7 +209,7 @@ func Benchmark_Memory_Set(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = testStore.Set("john", []byte("doe"), 0) //nolint: errcheck // error not needed for benchmark
_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
}
}
@ -220,7 +220,7 @@ func Benchmark_Memory_Set_Parallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = testStore.Set("john", []byte("doe"), 0) //nolint: errcheck // error not needed for benchmark
_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
}
})
}
@ -259,7 +259,7 @@ func Benchmark_Memory_Get(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = testStore.Get("john") //nolint: errcheck // error not needed for benchmark
_, _ = testStore.Get("john") //nolint:errcheck // error not needed for benchmark
}
}
@ -273,7 +273,7 @@ func Benchmark_Memory_Get_Parallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, _ = testStore.Get("john") //nolint: errcheck // error not needed for benchmark
_, _ = testStore.Get("john") //nolint:errcheck // error not needed for benchmark
}
})
}
@ -315,8 +315,8 @@ func Benchmark_Memory_SetAndDelete(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = testStore.Set("john", []byte("doe"), 0) //nolint: errcheck // error not needed for benchmark
_ = testStore.Delete("john") //nolint: errcheck // error not needed for benchmark
_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
_ = testStore.Delete("john") //nolint:errcheck // error not needed for benchmark
}
}
@ -327,8 +327,8 @@ func Benchmark_Memory_SetAndDelete_Parallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = testStore.Set("john", []byte("doe"), 0) //nolint: errcheck // error not needed for benchmark
_ = testStore.Delete("john") //nolint: errcheck // error not needed for benchmark
_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
_ = testStore.Delete("john") //nolint:errcheck // error not needed for benchmark
}
})
}

View File

@ -60,17 +60,6 @@ type ListenConfig struct {
// Default: nil
BeforeServeFunc func(app *App) error `json:"before_serve_func"`
// OnShutdownError allows to customize error behavior when to graceful shutdown server by given signal.
//
// Print error with log.Fatalf() by default.
// Default: nil
OnShutdownError func(err error)
// OnShutdownSuccess allows to customize success behavior when to graceful shutdown server by given signal.
//
// Default: nil
OnShutdownSuccess func()
// AutoCertManager manages TLS certificates automatically using the ACME protocol,
// Enables integration with Let's Encrypt or other ACME-compatible providers.
//
@ -102,7 +91,7 @@ type ListenConfig struct {
CertClientFile string `json:"cert_client_file"`
// When the graceful shutdown begins, use this field to set the timeout
// duration. If the timeout is reached, OnShutdownError will be called.
// duration. If the timeout is reached, OnPostShutdown will be called with the error.
// Set to 0 to disable the timeout and wait indefinitely.
//
// Default: 10 * time.Second
@ -136,9 +125,6 @@ func listenConfigDefault(config ...ListenConfig) ListenConfig {
return ListenConfig{
TLSMinVersion: tls.VersionTLS12,
ListenerNetwork: NetworkTCP4,
OnShutdownError: func(err error) {
log.Fatalf("shutdown: %v", err) //nolint:revive // It's an option
},
ShutdownTimeout: 10 * time.Second,
}
}
@ -148,12 +134,6 @@ func listenConfigDefault(config ...ListenConfig) ListenConfig {
cfg.ListenerNetwork = NetworkTCP4
}
if cfg.OnShutdownError == nil {
cfg.OnShutdownError = func(err error) {
log.Fatalf("shutdown: %v", err) //nolint:revive // It's an option
}
}
if cfg.TLSMinVersion == 0 {
cfg.TLSMinVersion = tls.VersionTLS12
}
@ -348,7 +328,7 @@ func (*App) prepareListenData(addr string, isTLS bool, cfg ListenConfig) ListenD
}
// startupMessage prepares the startup message with the handler number, port, address and other information
func (app *App) startupMessage(addr string, isTLS bool, pids string, cfg ListenConfig) { //nolint: revive // Accepting a bool param named isTLS if fine here
func (app *App) startupMessage(addr string, isTLS bool, pids string, cfg ListenConfig) { //nolint:revive // Accepting a bool param named isTLS if fine here
// ignore child processes
if IsChild() {
return
@ -386,38 +366,35 @@ func (app *App) startupMessage(addr string, isTLS bool, pids string, cfg ListenC
out = colorable.NewNonColorable(os.Stdout)
}
fmt.Fprintf(out, "%s\n", fmt.Sprintf(figletFiberText, colors.Red+"v"+Version+colors.Reset)) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, strings.Repeat("-", 50)+"\n") //nolint:errcheck,revive,govet // ignore error
fmt.Fprintf(out, "%s\n", fmt.Sprintf(figletFiberText, colors.Red+"v"+Version+colors.Reset))
fmt.Fprintf(out, strings.Repeat("-", 50)+"\n")
if host == "0.0.0.0" {
//nolint:errcheck,revive // ignore error
fmt.Fprintf(out,
"%sINFO%s Server started on: \t%s%s://127.0.0.1:%s%s (bound on host 0.0.0.0 and port %s)\n",
colors.Green, colors.Reset, colors.Blue, scheme, port, colors.Reset, port)
} else {
//nolint:errcheck,revive // ignore error
fmt.Fprintf(out,
"%sINFO%s Server started on: \t%s%s%s\n",
colors.Green, colors.Reset, colors.Blue, fmt.Sprintf("%s://%s:%s", scheme, host, port), colors.Reset)
}
if app.config.AppName != "" {
fmt.Fprintf(out, "%sINFO%s Application name: \t\t%s%s%s\n", colors.Green, colors.Reset, colors.Blue, app.config.AppName, colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "%sINFO%s Application name: \t\t%s%s%s\n", colors.Green, colors.Reset, colors.Blue, app.config.AppName, colors.Reset)
}
//nolint:errcheck,revive // ignore error
fmt.Fprintf(out,
"%sINFO%s Total handlers count: \t%s%s%s\n",
colors.Green, colors.Reset, colors.Blue, strconv.Itoa(int(app.handlersCount)), colors.Reset)
if isPrefork == "Enabled" {
fmt.Fprintf(out, "%sINFO%s Prefork: \t\t\t%s%s%s\n", colors.Green, colors.Reset, colors.Blue, isPrefork, colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "%sINFO%s Prefork: \t\t\t%s%s%s\n", colors.Green, colors.Reset, colors.Blue, isPrefork, colors.Reset)
} else {
fmt.Fprintf(out, "%sINFO%s Prefork: \t\t\t%s%s%s\n", colors.Green, colors.Reset, colors.Red, isPrefork, colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "%sINFO%s Prefork: \t\t\t%s%s%s\n", colors.Green, colors.Reset, colors.Red, isPrefork, colors.Reset)
}
fmt.Fprintf(out, "%sINFO%s PID: \t\t\t%s%v%s\n", colors.Green, colors.Reset, colors.Blue, os.Getpid(), colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "%sINFO%s Total process count: \t%s%s%s\n", colors.Green, colors.Reset, colors.Blue, procs, colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "%sINFO%s PID: \t\t\t%s%v%s\n", colors.Green, colors.Reset, colors.Blue, os.Getpid(), colors.Reset)
fmt.Fprintf(out, "%sINFO%s Total process count: \t%s%s%s\n", colors.Green, colors.Reset, colors.Blue, procs, colors.Reset)
if cfg.EnablePrefork {
// Turn the `pids` variable (in the form ",a,b,c,d,e,f,etc") into a slice of PIDs
@ -428,7 +405,7 @@ func (app *App) startupMessage(addr string, isTLS bool, pids string, cfg ListenC
}
}
fmt.Fprintf(out, "%sINFO%s Child PIDs: \t\t%s", colors.Green, colors.Reset, colors.Blue) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "%sINFO%s Child PIDs: \t\t%s", colors.Green, colors.Reset, colors.Blue)
totalPids := len(pidSlice)
rowTotalPidCount := 10
@ -441,17 +418,17 @@ func (app *App) startupMessage(addr string, isTLS bool, pids string, cfg ListenC
}
for n, pid := range pidSlice[start:end] {
fmt.Fprintf(out, "%s", pid) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "%s", pid)
if n+1 != len(pidSlice[start:end]) {
fmt.Fprintf(out, ", ") //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, ", ")
}
}
fmt.Fprintf(out, "\n%s", colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "\n%s", colors.Reset)
}
}
// add new Line as spacer
fmt.Fprintf(out, "\n%s", colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(out, "\n%s", colors.Reset)
}
// printRoutesMessage print all routes with method, path, name and handlers
@ -493,11 +470,10 @@ func (app *App) printRoutesMessage() {
return routes[i].path < routes[j].path
})
fmt.Fprintf(w, "%smethod\t%s| %spath\t%s| %sname\t%s| %shandlers\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(w, "%s------\t%s| %s----\t%s| %s----\t%s| %s--------\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset) //nolint:errcheck,revive // ignore error
fmt.Fprintf(w, "%smethod\t%s| %spath\t%s| %sname\t%s| %shandlers\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
fmt.Fprintf(w, "%s------\t%s| %s----\t%s| %s----\t%s| %s--------\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
for _, route := range routes {
//nolint:errcheck,revive // ignore error
fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers, colors.Reset)
}
@ -517,11 +493,9 @@ func (app *App) gracefulShutdown(ctx context.Context, cfg ListenConfig) {
}
if err != nil {
cfg.OnShutdownError(err)
app.hooks.executeOnPostShutdownHooks(err)
return
}
if success := cfg.OnShutdownSuccess; success != nil {
success()
}
app.hooks.executeOnPostShutdownHooks(nil)
}

View File

@ -15,6 +15,7 @@ import (
"testing"
"time"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
@ -37,98 +38,42 @@ func Test_Listen(t *testing.T) {
// go test -run Test_Listen_Graceful_Shutdown
func Test_Listen_Graceful_Shutdown(t *testing.T) {
var mu sync.Mutex
var shutdown bool
app := New()
app.Get("/", func(c Ctx) error {
return c.SendString(c.Hostname())
t.Run("Basic Graceful Shutdown", func(t *testing.T) {
testGracefulShutdown(t, 0)
})
ln := fasthttputil.NewInmemoryListener()
errs := make(chan error)
t.Run("Shutdown With Timeout", func(t *testing.T) {
testGracefulShutdown(t, 500*time.Millisecond)
})
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
errs <- app.Listener(ln, ListenConfig{
DisableStartupMessage: true,
GracefulContext: ctx,
OnShutdownSuccess: func() {
mu.Lock()
shutdown = true
mu.Unlock()
},
})
}()
// Server readiness check
for i := 0; i < 10; i++ {
conn, err := ln.Dial()
if err == nil {
conn.Close() //nolint:errcheck // ignore error
break
}
// Wait a bit before retrying
time.Sleep(100 * time.Millisecond)
if i == 9 {
t.Fatalf("Server did not become ready in time: %v", err)
}
}
testCases := []struct {
ExpectedErr error
ExpectedBody string
Time time.Duration
ExpectedStatusCode int
}{
{Time: 500 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil},
{Time: 3 * time.Second, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: fasthttputil.ErrInmemoryListenerClosed},
}
for _, tc := range testCases {
time.Sleep(tc.Time)
req := fasthttp.AcquireRequest()
req.SetRequestURI("http://example.com")
client := fasthttp.HostClient{}
client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
resp := fasthttp.AcquireResponse()
err := client.Do(req, resp)
require.Equal(t, tc.ExpectedErr, err)
require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode())
require.Equal(t, tc.ExpectedBody, string(resp.Body()))
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
}
mu.Lock()
err := <-errs
require.True(t, shutdown)
require.NoError(t, err)
mu.Unlock()
t.Run("Shutdown With Timeout Error", func(t *testing.T) {
testGracefulShutdown(t, 1*time.Nanosecond)
})
}
// go test -run Test_Listen_Graceful_Shutdown_Timeout
func Test_Listen_Graceful_Shutdown_Timeout(t *testing.T) {
func testGracefulShutdown(t *testing.T, shutdownTimeout time.Duration) {
t.Helper()
var mu sync.Mutex
var shutdownSuccess bool
var shutdownTimeoutError error
var shutdown bool
var receivedErr error
app := New()
app.Get("/", func(c Ctx) error {
time.Sleep(10 * time.Millisecond)
return c.SendString(c.Hostname())
})
ln := fasthttputil.NewInmemoryListener()
errs := make(chan error)
errs := make(chan error, 1)
app.hooks.OnPostShutdown(func(err error) error {
mu.Lock()
defer mu.Unlock()
shutdown = true
receivedErr = err
return nil
})
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
@ -137,93 +82,83 @@ func Test_Listen_Graceful_Shutdown_Timeout(t *testing.T) {
errs <- app.Listener(ln, ListenConfig{
DisableStartupMessage: true,
GracefulContext: ctx,
ShutdownTimeout: 500 * time.Millisecond,
OnShutdownSuccess: func() {
mu.Lock()
shutdownSuccess = true
mu.Unlock()
},
OnShutdownError: func(err error) {
mu.Lock()
shutdownTimeoutError = err
mu.Unlock()
},
ShutdownTimeout: shutdownTimeout,
})
}()
// Server readiness check
for i := 0; i < 10; i++ {
require.Eventually(t, func() bool {
conn, err := ln.Dial()
// To test a graceful shutdown timeout, do not close the connection.
if err == nil {
_ = conn
break
}
// Wait a bit before retrying
time.Sleep(100 * time.Millisecond)
if i == 9 {
t.Fatalf("Server did not become ready in time: %v", err)
if err := conn.Close(); err != nil {
t.Logf("error closing connection: %v", err)
}
return true
}
return false
}, time.Second, 100*time.Millisecond, "Server failed to become ready")
client := fasthttp.HostClient{
Dial: func(_ string) (net.Conn, error) { return ln.Dial() },
}
testCases := []struct {
ExpectedErr error
ExpectedShutdownError error
ExpectedBody string
Time time.Duration
ExpectedStatusCode int
ExpectedShutdownSuccess bool
}{
type testCase struct {
expectedErr error
expectedBody string
name string
waitTime time.Duration
expectedStatusCode int
closeConnection bool
}
testCases := []testCase{
{
Time: 100 * time.Millisecond,
ExpectedBody: "example.com",
ExpectedStatusCode: StatusOK,
ExpectedErr: nil,
ExpectedShutdownError: nil,
ExpectedShutdownSuccess: false,
name: "Server running normally",
waitTime: 500 * time.Millisecond,
expectedBody: "example.com",
expectedStatusCode: StatusOK,
expectedErr: nil,
closeConnection: true,
},
{
Time: 3 * time.Second,
ExpectedBody: "",
ExpectedStatusCode: StatusOK,
ExpectedErr: fasthttputil.ErrInmemoryListenerClosed,
ExpectedShutdownError: context.DeadlineExceeded,
ExpectedShutdownSuccess: false,
name: "Server shutdown complete",
waitTime: 3 * time.Second,
expectedBody: "",
expectedStatusCode: StatusOK,
expectedErr: fasthttputil.ErrInmemoryListenerClosed,
closeConnection: true,
},
}
for _, tc := range testCases {
time.Sleep(tc.Time)
t.Run(tc.name, func(t *testing.T) {
time.Sleep(tc.waitTime)
req := fasthttp.AcquireRequest()
req.SetRequestURI("http://example.com")
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
req.SetRequestURI("http://example.com")
client := fasthttp.HostClient{}
client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
resp := fasthttp.AcquireResponse()
err := client.Do(req, resp)
err := client.Do(req, resp)
if err == nil {
require.NoError(t, err)
require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode())
require.Equal(t, tc.ExpectedBody, string(resp.Body()))
} else {
require.ErrorIs(t, err, tc.ExpectedErr)
}
mu.Lock()
require.Equal(t, tc.ExpectedShutdownSuccess, shutdownSuccess)
require.Equal(t, tc.ExpectedShutdownError, shutdownTimeoutError)
mu.Unlock()
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
if tc.expectedErr == nil {
require.NoError(t, err)
require.Equal(t, tc.expectedStatusCode, resp.StatusCode())
require.Equal(t, tc.expectedBody, utils.UnsafeString(resp.Body()))
} else {
require.ErrorIs(t, err, tc.expectedErr)
}
})
}
mu.Lock()
err := <-errs
require.NoError(t, err)
require.True(t, shutdown)
if shutdownTimeout == 1*time.Nanosecond {
require.Error(t, receivedErr)
require.ErrorIs(t, receivedErr, context.DeadlineExceeded)
}
require.NoError(t, <-errs)
mu.Unlock()
}

View File

@ -53,9 +53,9 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) {
buf.WriteString(level)
if len(fmtArgs) > 0 {
_, _ = fmt.Fprintf(buf, format, fmtArgs...) //nolint: errcheck // It is fine to ignore the error
_, _ = fmt.Fprintf(buf, format, fmtArgs...)
} else {
_, _ = fmt.Fprint(buf, fmtArgs...) //nolint: errcheck // It is fine to ignore the error
_, _ = fmt.Fprint(buf, fmtArgs...)
}
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error

View File

@ -1,4 +1,4 @@
//nolint:contextcheck, revive // Much easier to just ignore memory leaks in tests
//nolint:contextcheck,revive // Much easier to just ignore memory leaks in tests
package adaptor
import (
@ -68,7 +68,7 @@ func Test_HTTPHandler(t *testing.T) {
w.Header().Set("Header1", "value1")
w.Header().Set("Header2", "value2")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "request body is %q", body) //nolint:errcheck // not needed
fmt.Fprintf(w, "request body is %q", body)
}
fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH))
fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue)

View File

@ -35,17 +35,17 @@ const (
noStore = "no-store"
)
var ignoreHeaders = map[string]any{
"Connection": nil,
"Keep-Alive": nil,
"Proxy-Authenticate": nil,
"Proxy-Authorization": nil,
"TE": nil,
"Trailers": nil,
"Transfer-Encoding": nil,
"Upgrade": nil,
"Content-Type": nil, // already stored explicitly by the cache manager
"Content-Encoding": nil, // already stored explicitly by the cache manager
var ignoreHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"TE": {},
"Trailers": {},
"Transfer-Encoding": {},
"Upgrade": {},
"Content-Type": {}, // already stored explicitly by the cache manager
"Content-Encoding": {}, // already stored explicitly by the cache manager
}
var cacheableStatusCodes = map[int]bool{

View File

@ -1333,56 +1333,65 @@ func Test_CSRF_Cookie_Injection_Exploit(t *testing.T) {
}
// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
// t.Parallel()
// app := fiber.New()
func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
t.SkipNow()
t.Parallel()
app := fiber.New()
// app.Use(New())
// app.Get("/", func(c fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Get("/test", func(c fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
// app.Post("/", func(c fiber.Ctx) error {
// return c.SendStatus(fiber.StatusOK)
// })
app.Use(New())
app.Get("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
app.Get("/test", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
// resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
// require.NoError(t, err)
// require.Equal(t, fiber.StatusOK, resp.StatusCode)
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
// var token string
// for _, c := range resp.Cookies() {
// if c.Name != ConfigDefault.CookieName {
// continue
// }
// token = c.Value
// break
// }
var token string
for _, c := range resp.Cookies() {
if c.Name != ConfigDefault.CookieName {
continue
}
token = c.Value
break
}
// fmt.Println("token", token)
t.Log("token", token)
// getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
// getReq.Header.Set(HeaderName, token)
// resp, err = app.Test(getReq)
getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
getReq.Header.Set(HeaderName, token)
resp, err = app.Test(getReq)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
// getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
// getReq.Header.Set(HeaderName, token)
getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
getReq.Header.Set(fiber.HeaderCacheControl, "no")
getReq.Header.Set(HeaderName, token)
// resp, err = app.Test(getReq)
resp, err = app.Test(getReq)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
// getReq.Header.Del(HeaderName)
// resp, err = app.Test(getReq)
getReq.Header.Set(fiber.HeaderAccept, "*/*")
getReq.Header.Del(HeaderName)
resp, err = app.Test(getReq)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
// postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// postReq.Header.Set(HeaderName, token)
// resp, err = app.Test(postReq)
// }
postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
postReq.Header.Set(HeaderName, token)
resp, err = app.Test(postReq)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
func Benchmark_Middleware_CSRF_Check(b *testing.B) {

View File

@ -18,18 +18,18 @@ type Config struct {
// the application is in a state where it can handle requests (e.g., the server is up and running).
//
// Optional. Default: func(c fiber.Ctx) bool { return true }
Probe HealthChecker
Probe func(fiber.Ctx) bool
}
const (
DefaultLivenessEndpoint = "/livez"
DefaultReadinessEndpoint = "/readyz"
DefaultStartupEndpoint = "/startupz"
LivenessEndpoint = "/livez"
ReadinessEndpoint = "/readyz"
StartupEndpoint = "/startupz"
)
func defaultProbe(fiber.Ctx) bool { return true }
func defaultProbe(_ fiber.Ctx) bool { return true }
func defaultConfigV3(config ...Config) Config {
func defaultConfig(config ...Config) Config {
if len(config) < 1 {
return Config{
Probe: defaultProbe,

View File

@ -4,11 +4,8 @@ import (
"github.com/gofiber/fiber/v3"
)
// HealthChecker defines a function to check liveness or readiness of the application
type HealthChecker func(fiber.Ctx) bool
func NewHealthChecker(config ...Config) fiber.Handler {
cfg := defaultConfigV3(config...)
func New(config ...Config) fiber.Handler {
cfg := defaultConfig(config...)
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true

View File

@ -34,9 +34,9 @@ func Test_HealthCheck_Strict_Routing_Default(t *testing.T) {
StrictRouting: true,
})
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
app.Get(DefaultStartupEndpoint, NewHealthChecker())
app.Get(LivenessEndpoint, New())
app.Get(ReadinessEndpoint, New())
app.Get(StartupEndpoint, New())
shouldGiveOK(t, app, "/readyz")
shouldGiveOK(t, app, "/livez")
@ -53,9 +53,9 @@ func Test_HealthCheck_Default(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
app.Get(DefaultStartupEndpoint, NewHealthChecker())
app.Get(LivenessEndpoint, New())
app.Get(ReadinessEndpoint, New())
app.Get(StartupEndpoint, New())
shouldGiveOK(t, app, "/readyz")
shouldGiveOK(t, app, "/livez")
@ -73,12 +73,12 @@ func Test_HealthCheck_Custom(t *testing.T) {
app := fiber.New()
c1 := make(chan struct{}, 1)
app.Get("/live", NewHealthChecker(Config{
app.Get("/live", New(Config{
Probe: func(_ fiber.Ctx) bool {
return true
},
}))
app.Get("/ready", NewHealthChecker(Config{
app.Get("/ready", New(Config{
Probe: func(_ fiber.Ctx) bool {
select {
case <-c1:
@ -88,7 +88,7 @@ func Test_HealthCheck_Custom(t *testing.T) {
}
},
}))
app.Get(DefaultStartupEndpoint, NewHealthChecker(Config{
app.Get(StartupEndpoint, New(Config{
Probe: func(_ fiber.Ctx) bool {
return false
},
@ -123,12 +123,12 @@ func Test_HealthCheck_Custom_Nested(t *testing.T) {
app := fiber.New()
c1 := make(chan struct{}, 1)
app.Get("/probe/live", NewHealthChecker(Config{
app.Get("/probe/live", New(Config{
Probe: func(_ fiber.Ctx) bool {
return true
},
}))
app.Get("/probe/ready", NewHealthChecker(Config{
app.Get("/probe/ready", New(Config{
Probe: func(_ fiber.Ctx) bool {
select {
case <-c1:
@ -164,15 +164,15 @@ func Test_HealthCheck_Next(t *testing.T) {
app := fiber.New()
checker := NewHealthChecker(Config{
checker := New(Config{
Next: func(_ fiber.Ctx) bool {
return true
},
})
app.Get(DefaultLivenessEndpoint, checker)
app.Get(DefaultReadinessEndpoint, checker)
app.Get(DefaultStartupEndpoint, checker)
app.Get(LivenessEndpoint, checker)
app.Get(ReadinessEndpoint, checker)
app.Get(StartupEndpoint, checker)
// This should give not found since there are no other handlers to execute
// so it's like the route isn't defined at all
@ -184,9 +184,9 @@ func Test_HealthCheck_Next(t *testing.T) {
func Benchmark_HealthCheck(b *testing.B) {
app := fiber.New()
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
app.Get(DefaultStartupEndpoint, NewHealthChecker())
app.Get(LivenessEndpoint, New())
app.Get(ReadinessEndpoint, New())
app.Get(StartupEndpoint, New())
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
@ -206,9 +206,9 @@ func Benchmark_HealthCheck(b *testing.B) {
func Benchmark_HealthCheck_Parallel(b *testing.B) {
app := fiber.New()
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
app.Get(DefaultStartupEndpoint, NewHealthChecker())
app.Get(LivenessEndpoint, New())
app.Get(ReadinessEndpoint, New())
app.Get(StartupEndpoint, New())
h := app.Handler()

View File

@ -28,7 +28,7 @@ type Config struct {
ContentSecurityPolicy string
// ReferrerPolicy
// Optional. Default value "ReferrerPolicy".
// Optional. Default value "no-referrer".
ReferrerPolicy string
// Permissions-Policy

View File

@ -10,16 +10,21 @@ import (
// Config defines the config for middleware.
type Config struct {
// Output is a writer where logs are written
// Stream is a writer where logs are written
//
// Default: os.Stdout
Output io.Writer
Stream io.Writer
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// Skip is a function to determine if logging is skipped or written to Stream.
//
// Optional. Default: nil
Skip func(c fiber.Ctx) bool
// Done is a function that is called after the log string for a request is written to Output,
// and pass the log string as parameter.
//
@ -45,9 +50,23 @@ type Config struct {
timeZoneLocation *time.Location
// Format defines the logging tags
// Format defines the logging format for the middleware.
//
// Optional. Default: [${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}
// You can customize the log output by defining a format string with placeholders
// such as: ${time}, ${ip}, ${status}, ${method}, ${path}, ${latency}, ${error}, etc.
// The full list of available placeholders can be found in 'tags.go' or at
// 'https://docs.gofiber.io/api/middleware/logger/#constants'.
//
// Fiber provides predefined logging formats that can be used directly:
//
// - DefaultFormat → Uses the default log format: "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}"
// - CommonFormat → Uses the Apache Common Log Format (CLF): "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent}\n"
// - CombinedFormat → Uses the Apache Combined Log Format: "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent} \"${referer}\" \"${ua}\"\n"
// - JSONFormat → Uses the JSON log format: "{\"time\":\"${time}\",\"ip\":\"${ip}\",\"method\":\"${method}\",\"url\":\"${url}\",\"status\":${status},\"bytesSent\":${bytesSent}}\n"
// - ECSFormat → Uses the Elastic Common Schema (ECS) log format: {\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}"
// If both `Format` and `CustomFormat` are provided, the `CustomFormat` will be used, and the `Format` field will be ignored.
// If no format is specified, the default format is used:
// "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}"
Format string
// TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html
@ -98,20 +117,18 @@ type LogFunc func(output Buffer, c fiber.Ctx, data *Data, extraParam string) (in
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Skip: nil,
Done: nil,
Format: defaultFormat,
Format: DefaultFormat,
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Output: os.Stdout,
Stream: os.Stdout,
BeforeHandlerFunc: beforeHandlerFunc,
LoggerFunc: defaultLoggerInstance,
enableColors: true,
}
// default logging format for Fiber's default logger
var defaultFormat = "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n"
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
@ -126,6 +143,9 @@ func configDefault(config ...Config) Config {
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Skip == nil {
cfg.Skip = ConfigDefault.Skip
}
if cfg.Done == nil {
cfg.Done = ConfigDefault.Done
}
@ -141,8 +161,8 @@ func configDefault(config ...Config) Config {
if int(cfg.TimeInterval) <= 0 {
cfg.TimeInterval = ConfigDefault.TimeInterval
}
if cfg.Output == nil {
cfg.Output = ConfigDefault.Output
if cfg.Stream == nil {
cfg.Stream = ConfigDefault.Stream
}
if cfg.BeforeHandlerFunc == nil {
@ -154,7 +174,7 @@ func configDefault(config ...Config) Config {
}
// Enable colors if no custom format or output is given
if !cfg.DisableColors && cfg.Output == ConfigDefault.Output {
if !cfg.DisableColors && cfg.Stream == ConfigDefault.Stream {
cfg.enableColors = true
}

View File

@ -15,6 +15,12 @@ import (
// default logger for fiber
func defaultLoggerInstance(c fiber.Ctx, data *Data, cfg Config) error {
// Check if Skip is defined and call it.
// Now, if Skip(c) == true, we SKIP logging:
if cfg.Skip != nil && cfg.Skip(c) {
return nil // Skip logging if Skip returns true
}
// Alias colors
colors := c.App().Config().ColorScheme
@ -22,7 +28,7 @@ func defaultLoggerInstance(c fiber.Ctx, data *Data, cfg Config) error {
buf := bytebufferpool.Get()
// Default output when no custom Format or io.Writer is given
if cfg.Format == defaultFormat {
if cfg.Format == DefaultFormat {
// Format error if exist
formatErr := ""
if cfg.enableColors {
@ -91,7 +97,7 @@ func defaultLoggerInstance(c fiber.Ctx, data *Data, cfg Config) error {
}
// Write buffer to output
writeLog(cfg.Output, buf.Bytes())
writeLog(cfg.Stream, buf.Bytes())
if cfg.Done != nil {
cfg.Done(c, buf.Bytes())
@ -125,7 +131,7 @@ func defaultLoggerInstance(c fiber.Ctx, data *Data, cfg Config) error {
buf.WriteString(err.Error())
}
writeLog(cfg.Output, buf.Bytes())
writeLog(cfg.Stream, buf.Bytes())
if cfg.Done != nil {
cfg.Done(c, buf.Bytes())
@ -141,9 +147,9 @@ func defaultLoggerInstance(c fiber.Ctx, data *Data, cfg Config) error {
func beforeHandlerFunc(cfg Config) {
// If colors are enabled, check terminal compatibility
if cfg.enableColors {
cfg.Output = colorable.NewColorableStdout()
cfg.Stream = colorable.NewColorableStdout()
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
cfg.Output = colorable.NewNonColorable(os.Stdout)
cfg.Stream = colorable.NewNonColorable(os.Stdout)
}
}
}
@ -160,7 +166,7 @@ func writeLog(w io.Writer, msg []byte) {
// Write error to output
if _, err := w.Write([]byte(err.Error())); err != nil {
// There is something wrong with the given io.Writer
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err) //nolint: errcheck // It is fine to ignore the error
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
}
}
}

View File

@ -0,0 +1,14 @@
package logger
const (
// Fiber's default logger
DefaultFormat = "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n"
// Apache Common Log Format (CLF)
CommonFormat = "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent}\n"
// Apache Combined Log Format
CombinedFormat = "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent} \"${referer}\" \"${ua}\"\n"
// JSON log formats
JSONFormat = "{\"time\":\"${time}\",\"ip\":\"${ip}\",\"method\":\"${method}\",\"url\":\"${url}\",\"status\":${status},\"bytesSent\":${bytesSent}}\n"
// Elastic Common Schema (ECS) Log Format
ECSFormat = "{\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}\n"
)

View File

@ -40,7 +40,6 @@ func New(config ...Config) fiber.Handler {
}
}()
}
// Set PID once
pid := strconv.Itoa(os.Getpid())

View File

@ -71,7 +71,7 @@ func Test_Logger(t *testing.T) {
app.Use(New(Config{
Format: "${error}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(_ fiber.Ctx) error {
@ -94,7 +94,7 @@ func Test_Logger_locals(t *testing.T) {
app.Use(New(Config{
Format: "${locals:demo}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
@ -171,6 +171,147 @@ func Test_Logger_Done(t *testing.T) {
require.Positive(t, buf.Len(), 0)
}
// Test_Logger_Filter tests the Filter functionality of the logger middleware.
// It verifies that logs are written or skipped based on the filter condition.
func Test_Logger_Filter(t *testing.T) {
t.Parallel()
t.Run("Test Not Found", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Return true to skip logging for all requests != 404
app.Use(New(Config{
Skip: func(c fiber.Ctx) bool {
return c.Response().StatusCode() != fiber.StatusNotFound
},
Stream: &logOutput,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/nonexistent", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
// Expect logs for the 404 request
require.Contains(t, logOutput.String(), "404")
})
t.Run("Test OK", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Return true to skip logging for all requests == 200
app.Use(New(Config{
Skip: func(c fiber.Ctx) bool {
return c.Response().StatusCode() == fiber.StatusOK
},
Stream: &logOutput,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
// We skip logging for status == 200, so "200" should not appear
require.NotContains(t, logOutput.String(), "200")
})
t.Run("Always Skip", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Filter always returns true => skip all logs
app.Use(New(Config{
Skip: func(_ fiber.Ctx) bool {
return true // always skip
},
Stream: &logOutput,
}))
app.Get("/something", func(c fiber.Ctx) error {
return c.Status(fiber.StatusTeapot).SendString("I'm a teapot")
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/something", nil))
require.NoError(t, err)
// Expect NO logs
require.Empty(t, logOutput.String())
})
t.Run("Never Skip", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Filter always returns false => never skip logs
app.Use(New(Config{
Skip: func(_ fiber.Ctx) bool {
return false // never skip
},
Stream: &logOutput,
}))
app.Get("/always", func(c fiber.Ctx) error {
return c.Status(fiber.StatusTeapot).SendString("Teapot again")
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/always", nil))
require.NoError(t, err)
// Expect some logging - check any substring
require.Contains(t, logOutput.String(), strconv.Itoa(fiber.StatusTeapot))
})
t.Run("Skip /healthz", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Filter returns true (skip logs) if the request path is /healthz
app.Use(New(Config{
Skip: func(c fiber.Ctx) bool {
return c.Path() == "/healthz"
},
Stream: &logOutput,
}))
// Normal route
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello World!")
})
// Health route
app.Get("/healthz", func(c fiber.Ctx) error {
return c.SendString("OK")
})
// Request to "/" -> should be logged
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Contains(t, logOutput.String(), "200")
// Reset output buffer
logOutput.Reset()
// Request to "/healthz" -> should be skipped
_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/healthz", nil))
require.NoError(t, err)
require.Empty(t, logOutput.String())
})
}
// go test -run Test_Logger_ErrorTimeZone
func Test_Logger_ErrorTimeZone(t *testing.T) {
t.Parallel()
@ -234,7 +375,7 @@ func Test_Logger_LoggerToWriter(t *testing.T) {
app.Use("/"+level, New(Config{
Format: "${error}",
Output: LoggerToWriter(logger, tc.
Stream: LoggerToWriter(logger, tc.
level),
}))
@ -276,7 +417,7 @@ func Test_Logger_ErrorOutput_WithoutColor(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Output: o,
Stream: o,
DisableColors: true,
}))
@ -293,7 +434,7 @@ func Test_Logger_ErrorOutput(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Output: o,
Stream: o,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@ -312,7 +453,7 @@ func Test_Logger_All(t *testing.T) {
app.Use(New(Config{
Format: "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}",
Output: buf,
Stream: buf,
}))
// Alias colors
@ -326,6 +467,124 @@ func Test_Logger_All(t *testing.T) {
require.Equal(t, expected, buf.String())
}
func Test_Logger_CLF_Format(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: CommonFormat,
Stream: buf,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
expected := fmt.Sprintf("0.0.0.0 - - [%s] \"%s %s %s\" %d %d\n",
time.Now().Format("15:04:05"),
fiber.MethodGet, "/?foo=bar", "HTTP/1.1",
fiber.StatusNotFound,
0)
logResponse := buf.String()
require.Equal(t, expected, logResponse)
}
func Test_Logger_Combined_CLF_Format(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: CombinedFormat,
Stream: buf,
}))
const expectedUA = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/74.0.3729.169 Safari/537.36"
const expectedReferer = "http://example.com"
req := httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil)
req.Header.Set("Referer", expectedReferer)
req.Header.Set("User-Agent", expectedUA)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
expected := fmt.Sprintf("0.0.0.0 - - [%s] %q %d %d %q %q\n",
time.Now().Format("15:04:05"),
fmt.Sprintf("%s %s %s", fiber.MethodGet, "/?foo=bar", "HTTP/1.1"),
fiber.StatusNotFound,
0,
expectedReferer,
expectedUA)
logResponse := buf.String()
require.Equal(t, expected, logResponse)
}
func Test_Logger_Json_Format(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: JSONFormat,
Stream: buf,
}))
req := httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
expected := fmt.Sprintf(
"{\"time\":%q,\"ip\":%q,\"method\":%q,\"url\":%q,\"status\":%d,\"bytesSent\":%d}\n",
time.Now().Format("15:04:05"),
"0.0.0.0",
fiber.MethodGet,
"/?foo=bar",
fiber.StatusNotFound,
0,
)
logResponse := buf.String()
require.Equal(t, expected, logResponse)
}
func Test_Logger_ECS_Format(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: ECSFormat,
Stream: buf,
}))
req := httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
expected := fmt.Sprintf(
"{\"@timestamp\":%q,\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":%q},\"http\":{\"request\":{\"method\":%q,\"url\":%q,\"protocol\":%q},\"response\":{\"status_code\":%d,\"body\":{\"bytes\":%d}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":%q}\n",
time.Now().Format("15:04:05"),
"0.0.0.0",
fiber.MethodGet,
"/?foo=bar",
"HTTP/1.1",
fiber.StatusNotFound,
0,
fmt.Sprintf("%s %s responded with %d", fiber.MethodGet, "/?foo=bar", fiber.StatusNotFound),
)
logResponse := buf.String()
require.Equal(t, expected, logResponse)
}
func getLatencyTimeUnits() []struct {
unit string
div time.Duration
@ -358,7 +617,7 @@ func Test_Logger_WithLatency(t *testing.T) {
app := fiber.New()
logger := New(Config{
Output: buff,
Stream: buff,
Format: "${latency}",
})
app.Use(logger)
@ -403,7 +662,7 @@ func Test_Logger_WithLatency_DefaultFormat(t *testing.T) {
app := fiber.New()
logger := New(Config{
Output: buff,
Stream: buff,
})
app.Use(logger)
@ -453,7 +712,7 @@ func Test_Query_Params(t *testing.T) {
app.Use(New(Config{
Format: "${queryParams}",
Output: buf,
Stream: buf,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", nil))
@ -474,7 +733,7 @@ func Test_Response_Body(t *testing.T) {
app.Use(New(Config{
Format: "${resBody}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
@ -508,7 +767,7 @@ func Test_Request_Body(t *testing.T) {
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: buf,
Stream: buf,
}))
app.Post("/", func(c fiber.Ctx) error {
@ -536,7 +795,7 @@ func Test_Logger_AppendUint(t *testing.T) {
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
@ -611,7 +870,7 @@ func Test_Response_Header(t *testing.T) {
}))
app.Use(New(Config{
Format: "${respHeader:X-Request-ID}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello fiber!")
@ -634,7 +893,7 @@ func Test_Req_Header(t *testing.T) {
app.Use(New(Config{
Format: "${reqHeader:test}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello fiber!")
@ -658,7 +917,7 @@ func Test_ReqHeader_Header(t *testing.T) {
app.Use(New(Config{
Format: "${reqHeader:test}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello fiber!")
@ -689,7 +948,7 @@ func Test_CustomTags(t *testing.T) {
return output.WriteString(customTag)
},
},
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello fiber!")
@ -713,7 +972,7 @@ func Test_Logger_ByteSent_Streaming(t *testing.T) {
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: buf,
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
@ -724,7 +983,7 @@ func Test_Logger_ByteSent_Streaming(t *testing.T) {
for {
i++
msg := fmt.Sprintf("%d - the time is %v", i, time.Now())
fmt.Fprintf(w, "data: Message: %s\n\n", msg) //nolint:errcheck // ignore error
fmt.Fprintf(w, "data: Message: %s\n\n", msg)
err := w.Flush()
if err != nil {
break
@ -759,7 +1018,7 @@ func Test_Logger_EnableColors(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Output: o,
Stream: o,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@ -782,7 +1041,7 @@ func Benchmark_Logger(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set("test", "test")
@ -794,7 +1053,7 @@ func Benchmark_Logger(b *testing.B) {
b.Run("DefaultFormat", func(bb *testing.B) {
app := fiber.New()
app.Use(New(Config{
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
@ -805,7 +1064,7 @@ func Benchmark_Logger(b *testing.B) {
b.Run("DefaultFormatDisableColors", func(bb *testing.B) {
app := fiber.New()
app.Use(New(Config{
Output: io.Discard,
Stream: io.Discard,
DisableColors: true,
}))
app.Get("/", func(c fiber.Ctx) error {
@ -819,7 +1078,7 @@ func Benchmark_Logger(b *testing.B) {
logger := fiberlog.DefaultLogger()
logger.SetOutput(io.Discard)
app.Use(New(Config{
Output: LoggerToWriter(logger, fiberlog.LevelDebug),
Stream: LoggerToWriter(logger, fiberlog.LevelDebug),
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
@ -831,7 +1090,7 @@ func Benchmark_Logger(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set("test", "test")
@ -844,7 +1103,7 @@ func Benchmark_Logger(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${locals:demo}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Locals("demo", "johndoe")
@ -857,7 +1116,7 @@ func Benchmark_Logger(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${locals:demo}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/int", func(c fiber.Ctx) error {
c.Locals("demo", 55)
@ -874,7 +1133,7 @@ func Benchmark_Logger(b *testing.B) {
io.Discard.Write(logString) //nolint:errcheck // ignore error
}
},
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/logging", func(ctx fiber.Ctx) error {
return ctx.SendStatus(fiber.StatusOK)
@ -886,7 +1145,7 @@ func Benchmark_Logger(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
@ -898,7 +1157,7 @@ func Benchmark_Logger(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set("Connection", "keep-alive")
@ -908,7 +1167,7 @@ func Benchmark_Logger(b *testing.B) {
for {
i++
msg := fmt.Sprintf("%d - the time is %v", i, time.Now())
fmt.Fprintf(w, "data: Message: %s\n\n", msg) //nolint:errcheck // ignore error
fmt.Fprintf(w, "data: Message: %s\n\n", msg)
err := w.Flush()
if err != nil {
break
@ -927,7 +1186,7 @@ func Benchmark_Logger(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${resBody}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Sample response body")
@ -950,7 +1209,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set("test", "test")
@ -962,7 +1221,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
b.Run("DefaultFormat", func(bb *testing.B) {
app := fiber.New()
app.Use(New(Config{
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
@ -975,7 +1234,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
logger := fiberlog.DefaultLogger()
logger.SetOutput(io.Discard)
app.Use(New(Config{
Output: LoggerToWriter(logger, fiberlog.LevelDebug),
Stream: LoggerToWriter(logger, fiberlog.LevelDebug),
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
@ -986,7 +1245,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
b.Run("DefaultFormatDisableColors", func(bb *testing.B) {
app := fiber.New()
app.Use(New(Config{
Output: io.Discard,
Stream: io.Discard,
DisableColors: true,
}))
app.Get("/", func(c fiber.Ctx) error {
@ -999,7 +1258,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set("test", "test")
@ -1012,7 +1271,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${locals:demo}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Locals("demo", "johndoe")
@ -1025,7 +1284,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${locals:demo}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/int", func(c fiber.Ctx) error {
c.Locals("demo", 55)
@ -1042,7 +1301,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
io.Discard.Write(logString) //nolint:errcheck // ignore error
}
},
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/logging", func(ctx fiber.Ctx) error {
return ctx.SendStatus(fiber.StatusOK)
@ -1054,7 +1313,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
@ -1066,7 +1325,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set("Connection", "keep-alive")
@ -1076,7 +1335,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
for {
i++
msg := fmt.Sprintf("%d - the time is %v", i, time.Now())
fmt.Fprintf(w, "data: Message: %s\n\n", msg) //nolint:errcheck // ignore error
fmt.Fprintf(w, "data: Message: %s\n\n", msg)
err := w.Flush()
if err != nil {
break
@ -1095,7 +1354,7 @@ func Benchmark_Logger_Parallel(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${resBody}",
Output: io.Discard,
Stream: io.Discard,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Sample response body")

View File

@ -506,7 +506,10 @@ func Test_Proxy_Do_WithRealURL(t *testing.T) {
return Do(c, "https://www.google.com")
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{
Timeout: 2 * time.Second,
FailOnTimeout: true,
})
require.NoError(t, err1)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "/test", resp.Request.URL.String())
@ -523,7 +526,10 @@ func Test_Proxy_Do_WithRedirect(t *testing.T) {
return Do(c, "https://google.com")
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{
Timeout: 2 * time.Second,
FailOnTimeout: true,
})
require.NoError(t, err1)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
@ -558,7 +564,10 @@ func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
return DoRedirects(c, "http://google.com", 0)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{
Timeout: 2 * time.Second,
FailOnTimeout: true,
})
require.NoError(t, err1)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

View File

@ -66,6 +66,7 @@ func New(root string, cfg ...Config) fiber.Handler {
AcceptByteRange: config.ByteRange,
Compress: config.Compress,
CompressBrotli: config.Compress, // Brotli compression won't work without this
CompressZstd: config.Compress, // Zstd compression won't work without this
CompressedFileSuffixes: c.App().Config().CompressedFileSuffixes,
CacheDuration: config.CacheDuration,
SkipCache: config.CacheDuration < 0,

View File

@ -55,7 +55,7 @@ func (app *App) mount(prefix string, subApp *App) Router {
// register mounted group
mountGroup := &Group{Prefix: prefix, app: subApp}
app.register([]string{methodUse}, prefix, mountGroup, nil)
app.register([]string{methodUse}, prefix, mountGroup)
// Execute onMount hooks
if err := subApp.hooks.executeOnMountHooks(app); err != nil {
@ -85,7 +85,7 @@ func (grp *Group) mount(prefix string, subApp *App) Router {
// register mounted group
mountGroup := &Group{Prefix: groupPath, app: subApp}
grp.app.register([]string{methodUse}, groupPath, mountGroup, nil)
grp.app.register([]string{methodUse}, groupPath, mountGroup)
// Execute onMount hooks
if err := subApp.hooks.executeOnMountHooks(grp.app); err != nil {

212
path.go
View File

@ -7,9 +7,11 @@
package fiber
import (
"bytes"
"regexp"
"strconv"
"strings"
"sync"
"time"
"unicode"
@ -25,6 +27,12 @@ type routeParser struct {
plusCount int // number of plus parameters, used internally to give the plus parameter its number
}
var routerParserPool = &sync.Pool{
New: func() any {
return &routeParser{}
},
}
// routeSegment holds the segment metadata
type routeSegment struct {
// const information
@ -115,18 +123,6 @@ var (
parameterDelimiterChars = append([]byte{paramStarterChar, escapeChar}, routeDelimiter...)
// list of chars to find the end of a parameter
parameterEndChars = append([]byte{optionalParam}, parameterDelimiterChars...)
// list of parameter constraint start
parameterConstraintStartChars = []byte{paramConstraintStart}
// list of parameter constraint end
parameterConstraintEndChars = []byte{paramConstraintEnd}
// list of parameter separator
parameterConstraintSeparatorChars = []byte{paramConstraintSeparator}
// list of parameter constraint data start
parameterConstraintDataStartChars = []byte{paramConstraintDataStart}
// list of parameter constraint data end
parameterConstraintDataEndChars = []byte{paramConstraintDataEnd}
// list of parameter constraint data separator
parameterConstraintDataSeparatorChars = []byte{paramConstraintDataSeparator}
)
// RoutePatternMatch checks if a given path matches a Fiber route pattern.
@ -152,11 +148,11 @@ func RoutePatternMatch(path, pattern string, cfg ...Config) bool {
pattern = "/" + pattern
}
patternPretty := pattern
patternPretty := []byte(pattern)
// Case-sensitive routing, all to lowercase
if !config.CaseSensitive {
patternPretty = utils.ToLower(patternPretty)
patternPretty = utils.ToLowerBytes(patternPretty)
path = utils.ToLower(path)
}
// Strict routing, remove trailing slashes
@ -164,12 +160,15 @@ func RoutePatternMatch(path, pattern string, cfg ...Config) bool {
patternPretty = utils.TrimRight(patternPretty, '/')
}
parser := parseRoute(patternPretty)
parser, _ := routerParserPool.Get().(*routeParser) //nolint:errcheck // only contains routeParser
parser.reset()
parser.parseRoute(string(patternPretty))
defer routerParserPool.Put(parser)
if patternPretty == "/" && path == "/" {
if string(patternPretty) == "/" && path == "/" {
return true
// '*' wildcard matches any path
} else if patternPretty == "/*" {
} else if string(patternPretty) == "/*" {
return true
}
@ -180,42 +179,47 @@ func RoutePatternMatch(path, pattern string, cfg ...Config) bool {
}
}
// Check for a simple match
patternPretty = RemoveEscapeChar(patternPretty)
if len(patternPretty) == len(path) && patternPretty == path {
return true
}
// No match
return false
patternPretty = RemoveEscapeCharBytes(patternPretty)
return string(patternPretty) == path
}
func (parser *routeParser) reset() {
parser.segs = parser.segs[:0]
parser.params = parser.params[:0]
parser.wildCardCount = 0
parser.plusCount = 0
}
// parseRoute analyzes the route and divides it into segments for constant areas and parameters,
// this information is needed later when assigning the requests to the declared routes
func parseRoute(pattern string, customConstraints ...CustomConstraint) routeParser {
parser := routeParser{}
part := ""
func (parser *routeParser) parseRoute(pattern string, customConstraints ...CustomConstraint) {
var n int
var seg *routeSegment
for len(pattern) > 0 {
nextParamPosition := findNextParamPosition(pattern)
// handle the parameter part
if nextParamPosition == 0 {
processedPart, seg := parser.analyseParameterPart(pattern, customConstraints...)
parser.params, parser.segs, part = append(parser.params, seg.ParamName), append(parser.segs, seg), processedPart
n, seg = parser.analyseParameterPart(pattern, customConstraints...)
parser.params, parser.segs = append(parser.params, seg.ParamName), append(parser.segs, seg)
} else {
processedPart, seg := parser.analyseConstantPart(pattern, nextParamPosition)
parser.segs, part = append(parser.segs, seg), processedPart
n, seg = parser.analyseConstantPart(pattern, nextParamPosition)
parser.segs = append(parser.segs, seg)
}
// reduce the pattern by the processed parts
if len(part) == len(pattern) {
break
}
pattern = pattern[len(part):]
pattern = pattern[n:]
}
// mark last segment
if len(parser.segs) > 0 {
parser.segs[len(parser.segs)-1].IsLast = true
}
parser.segs = addParameterMetaInfo(parser.segs)
}
// parseRoute analyzes the route and divides it into segments for constant areas and parameters,
// this information is needed later when assigning the requests to the declared routes
func parseRoute(pattern string, customConstraints ...CustomConstraint) routeParser {
parser := routeParser{}
parser.parseRoute(pattern, customConstraints...)
return parser
}
@ -283,7 +287,7 @@ func findNextParamPosition(pattern string) int {
}
// analyseConstantPart find the end of the constant part and create the route segment
func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) (string, *routeSegment) {
func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) (int, *routeSegment) {
// handle the constant part
processedPart := pattern
if nextParamPosition != -1 {
@ -291,14 +295,14 @@ func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) (
processedPart = pattern[:nextParamPosition]
}
constPart := RemoveEscapeChar(processedPart)
return processedPart, &routeSegment{
return len(processedPart), &routeSegment{
Const: constPart,
Length: len(constPart),
}
}
// analyseParameterPart find the parameter end and create the route segment
func (routeParser *routeParser) analyseParameterPart(pattern string, customConstraints ...CustomConstraint) (string, *routeSegment) {
func (parser *routeParser) analyseParameterPart(pattern string, customConstraints ...CustomConstraint) (int, *routeSegment) {
isWildCard := pattern[0] == wildcardParam
isPlusParam := pattern[0] == plusParam
@ -317,18 +321,19 @@ func (routeParser *routeParser) analyseParameterPart(pattern string, customConst
parameterEndPosition = 0
case parameterEndPosition == -1:
parameterEndPosition = len(pattern) - 1
case !isInCharset(pattern[parameterEndPosition+1], parameterDelimiterChars):
case bytes.IndexByte(parameterDelimiterChars, pattern[parameterEndPosition+1]) == -1:
parameterEndPosition++
}
// find constraint part if exists in the parameter part and remove it
if parameterEndPosition > 0 {
parameterConstraintStart = findNextNonEscapedCharsetPosition(pattern[0:parameterEndPosition], parameterConstraintStartChars)
parameterConstraintEnd = findLastCharsetPosition(pattern[0:parameterEndPosition+1], parameterConstraintEndChars)
parameterConstraintStart = findNextNonEscapedCharPosition(pattern[:parameterEndPosition], paramConstraintStart)
parameterConstraintEnd = strings.LastIndexByte(pattern[:parameterEndPosition+1], paramConstraintEnd)
}
// cut params part
processedPart := pattern[0 : parameterEndPosition+1]
n := parameterEndPosition + 1
paramName := RemoveEscapeChar(GetTrimmedParam(processedPart))
// Check has constraint
@ -336,12 +341,12 @@ func (routeParser *routeParser) analyseParameterPart(pattern string, customConst
if hasConstraint := parameterConstraintStart != -1 && parameterConstraintEnd != -1; hasConstraint {
constraintString := pattern[parameterConstraintStart+1 : parameterConstraintEnd]
userConstraints := splitNonEscaped(constraintString, string(parameterConstraintSeparatorChars))
userConstraints := splitNonEscaped(constraintString, paramConstraintSeparator)
constraints = make([]*Constraint, 0, len(userConstraints))
for _, c := range userConstraints {
start := findNextNonEscapedCharsetPosition(c, parameterConstraintDataStartChars)
end := findLastCharsetPosition(c, parameterConstraintDataEndChars)
start := findNextNonEscapedCharPosition(c, paramConstraintDataStart)
end := strings.LastIndexByte(c, paramConstraintDataEnd)
// Assign constraint
if start != -1 && end != -1 {
@ -353,7 +358,7 @@ func (routeParser *routeParser) analyseParameterPart(pattern string, customConst
// remove escapes from data
if constraint.ID != regexConstraint {
constraint.Data = splitNonEscaped(c[start+1:end], string(parameterConstraintDataSeparatorChars))
constraint.Data = splitNonEscaped(c[start+1:end], paramConstraintDataSeparator)
if len(constraint.Data) == 1 {
constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
} else if len(constraint.Data) == 2 { // This is fine, we simply expect two parts
@ -384,11 +389,11 @@ func (routeParser *routeParser) analyseParameterPart(pattern string, customConst
// add access iterator to wildcard and plus
if isWildCard {
routeParser.wildCardCount++
paramName += strconv.Itoa(routeParser.wildCardCount)
parser.wildCardCount++
paramName += strconv.Itoa(parser.wildCardCount)
} else if isPlusParam {
routeParser.plusCount++
paramName += strconv.Itoa(routeParser.plusCount)
parser.plusCount++
paramName += strconv.Itoa(parser.plusCount)
}
segment := &routeSegment{
@ -402,17 +407,7 @@ func (routeParser *routeParser) analyseParameterPart(pattern string, customConst
segment.Constraints = constraints
}
return processedPart, segment
}
// isInCharset check is the given character in the charset list
func isInCharset(searchChar byte, charset []byte) bool {
for _, char := range charset {
if char == searchChar {
return true
}
}
return false
return n, segment
}
// findNextCharsetPosition search the next char position from the charset
@ -427,23 +422,11 @@ func findNextCharsetPosition(search string, charset []byte) int {
return nextPosition
}
// findLastCharsetPosition search the last char position from the charset
func findLastCharsetPosition(search string, charset []byte) int {
lastPosition := -1
for _, char := range charset {
if pos := strings.LastIndexByte(search, char); pos != -1 && (pos < lastPosition || lastPosition == -1) {
lastPosition = pos
}
}
return lastPosition
}
// findNextCharsetPositionConstraint search the next char position from the charset
// findNextCharsetPositionConstraint searches the next char position from the charset
// unlike findNextCharsetPosition, it takes care of constraint start-end chars to parse route pattern
func findNextCharsetPositionConstraint(search string, charset []byte) int {
constraintStart := findNextNonEscapedCharsetPosition(search, parameterConstraintStartChars)
constraintEnd := findNextNonEscapedCharsetPosition(search, parameterConstraintEndChars)
constraintStart := findNextNonEscapedCharPosition(search, paramConstraintStart)
constraintEnd := findNextNonEscapedCharPosition(search, paramConstraintEnd)
nextPosition := -1
for _, char := range charset {
@ -459,7 +442,7 @@ func findNextCharsetPositionConstraint(search string, charset []byte) int {
return nextPosition
}
// findNextNonEscapedCharsetPosition search the next char position from the charset and skip the escaped characters
// findNextNonEscapedCharsetPosition searches the next char position from the charset and skips the escaped characters
func findNextNonEscapedCharsetPosition(search string, charset []byte) int {
pos := findNextCharsetPosition(search, charset)
for pos > 0 && search[pos-1] == escapeChar {
@ -478,25 +461,35 @@ func findNextNonEscapedCharsetPosition(search string, charset []byte) int {
return pos
}
// findNextNonEscapedCharPosition searches the next char position and skips the escaped characters
func findNextNonEscapedCharPosition(search string, char byte) int {
for i := 0; i < len(search); i++ {
if search[i] == char && (i == 0 || search[i-1] != escapeChar) {
return i
}
}
return -1
}
// splitNonEscaped slices s into all substrings separated by sep and returns a slice of the substrings between those separators
// This function also takes a care of escape char when splitting.
func splitNonEscaped(s, sep string) []string {
func splitNonEscaped(s string, sep byte) []string {
var result []string
i := findNextNonEscapedCharsetPosition(s, []byte(sep))
i := findNextNonEscapedCharPosition(s, sep)
for i > -1 {
result = append(result, s[:i])
s = s[i+len(sep):]
i = findNextNonEscapedCharsetPosition(s, []byte(sep))
s = s[i+1:]
i = findNextNonEscapedCharPosition(s, sep)
}
return append(result, s)
}
// getMatch parses the passed url and tries to match it against the route segments and determine the parameter positions
func (routeParser *routeParser) getMatch(detectionPath, path string, params *[maxParams]string, partialCheck bool) bool { //nolint: revive // Accepting a bool param is fine here
func (parser *routeParser) getMatch(detectionPath, path string, params *[maxParams]string, partialCheck bool) bool { //nolint:revive // Accepting a bool param is fine here
var i, paramsIterator, partLen int
for _, segment := range routeParser.segs {
for _, segment := range parser.segs {
partLen = len(detectionPath)
// check const segment
if !segment.IsParam {
@ -618,7 +611,7 @@ func GetTrimmedParam(param string) string {
return param[start:end]
}
// RemoveEscapeChar remove escape characters
// RemoveEscapeChar removes escape characters
func RemoveEscapeChar(word string) string {
b := []byte(word)
dst := 0
@ -632,6 +625,18 @@ func RemoveEscapeChar(word string) string {
return string(b[:dst])
}
// RemoveEscapeCharBytes removes escape characters
func RemoveEscapeCharBytes(word []byte) []byte {
dst := 0
for src := 0; src < len(word); src++ {
if word[src] != '\\' {
word[dst] = word[src]
dst++
}
}
return word[:dst]
}
func getParamConstraintType(constraintPart string) TypeConstraint {
switch constraintPart {
case ConstraintInt:
@ -667,12 +672,25 @@ func getParamConstraintType(constraintPart string) TypeConstraint {
}
}
//nolint:errcheck // TODO: Properly check _all_ errors in here, log them & immediately return
// CheckConstraint validates if a param matches the given constraint
// Returns true if the param passes the constraint check, false otherwise
//
//nolint:errcheck // TODO: Properly check _all_ errors in here, log them or immediately return
func (c *Constraint) CheckConstraint(param string) bool {
var err error
var num int
// First check if there's a custom constraint with the same name
// This allows custom constraints to override built-in constraints
for _, cc := range c.customConstraints {
if cc.Name() == c.Name {
return cc.Execute(param, c.Data...)
}
}
// check data exists
var (
err error
num int
)
// Validate constraint has required data
needOneData := []TypeConstraint{minLenConstraint, maxLenConstraint, lenConstraint, minConstraint, maxConstraint, datetimeConstraint, regexConstraint}
needTwoData := []TypeConstraint{betweenLenConstraint, rangeConstraint}
@ -691,11 +709,7 @@ func (c *Constraint) CheckConstraint(param string) bool {
// check constraints
switch c.ID {
case noConstraint:
for _, cc := range c.customConstraints {
if cc.Name() == c.Name {
return cc.Execute(param, c.Data...)
}
}
return true
case intConstraint:
_, err = strconv.Atoi(param)
case boolConstraint:
@ -739,14 +753,14 @@ func (c *Constraint) CheckConstraint(param string) bool {
data, _ := strconv.Atoi(c.Data[0])
num, err = strconv.Atoi(param)
if num < data {
if err != nil || num < data {
return false
}
case maxConstraint:
data, _ := strconv.Atoi(c.Data[0])
num, err = strconv.Atoi(param)
if num > data {
if err != nil || num > data {
return false
}
case rangeConstraint:
@ -754,12 +768,18 @@ func (c *Constraint) CheckConstraint(param string) bool {
data2, _ := strconv.Atoi(c.Data[1])
num, err = strconv.Atoi(param)
if num < data || num > data2 {
if err != nil || num < data || num > data2 {
return false
}
case datetimeConstraint:
_, err = time.Parse(c.Data[0], param)
if err != nil {
return false
}
case regexConstraint:
if c.RegexCompiler == nil {
return false
}
if match := c.RegexCompiler.MatchString(param); !match {
return false
}

View File

@ -217,7 +217,7 @@ func Benchmark_Path_matchParams(t *testing.B) {
state = "not match"
}
t.Run(testCollection.pattern+" | "+state+" | "+c.url, func(b *testing.B) {
for i := 0; i <= b.N; i++ {
for i := 0; i < b.N; i++ {
if match := parser.getMatch(c.url, c.url, &ctxParams, c.partialCheck); match {
// Get testCases from the original path
matchRes = true
@ -250,7 +250,7 @@ func Benchmark_RoutePatternMatch(t *testing.B) {
state = "not match"
}
t.Run(testCollection.pattern+" | "+state+" | "+c.url, func(b *testing.B) {
for i := 0; i <= b.N; i++ {
for i := 0; i < b.N; i++ {
if match := RoutePatternMatch(c.url, testCollection.pattern); match {
// Get testCases from the original path
matchRes = true

View File

@ -713,6 +713,14 @@ func init() {
{url: "/api/v1/", params: []string{""}, match: true},
},
},
// Add test case for RegexCompiler == nil
{
pattern: "/api/v1/:param<regex(\\d+)>",
testCases: []routeTestCase{
{url: "/api/v1/123", params: []string{"123"}, match: true},
{url: "/api/v1/abc", params: nil, match: false},
},
},
}...,
)
}

View File

@ -178,7 +178,7 @@ func Test_Redirect_Back_WithFlashMessages(t *testing.T) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
err := c.Redirect().With("success", "1").With("message", "test").Back("/")
require.NoError(t, err)
@ -225,7 +225,7 @@ func Test_Redirect_Route_WithFlashMessages(t *testing.T) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
err := c.Redirect().With("success", "1").With("message", "test").Route("user")
@ -259,7 +259,7 @@ func Test_Redirect_Route_WithOldInput(t *testing.T) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().URI().SetQueryString("id=1&name=tom")
err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user")
@ -294,7 +294,7 @@ func Test_Redirect_Route_WithOldInput(t *testing.T) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().Header.Set(HeaderContentType, MIMEApplicationForm)
c.Request().SetBodyString("id=1&name=tom")
@ -330,7 +330,7 @@ func Test_Redirect_Route_WithOldInput(t *testing.T) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
@ -376,7 +376,7 @@ func Test_Redirect_parseAndClearFlashMessages(t *testing.T) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
msgs := redirectionMsgs{
{
@ -464,7 +464,7 @@ func Benchmark_Redirect_Route(b *testing.B) {
return c.JSON(c.Params("name"))
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -491,7 +491,7 @@ func Benchmark_Redirect_Route_WithQueries(b *testing.B) {
return c.JSON(c.Params("name"))
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -523,7 +523,7 @@ func Benchmark_Redirect_Route_WithFlashMessages(b *testing.B) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ReportAllocs()
b.ResetTimer()
@ -576,7 +576,7 @@ func Benchmark_Redirect_parseAndClearFlashMessages(b *testing.B) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
val, err := testredirectionMsgs.MarshalMsg(nil)
require.NoError(b, err)
@ -618,7 +618,7 @@ func Benchmark_Redirect_processFlashMessages(b *testing.B) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Redirect().With("success", "1").With("message", "test")
@ -647,7 +647,7 @@ func Benchmark_Redirect_Messages(b *testing.B) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
val, err := testredirectionMsgs.MarshalMsg(nil)
require.NoError(b, err)
@ -684,7 +684,7 @@ func Benchmark_Redirect_OldInputs(b *testing.B) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
val, err := testredirectionMsgs.MarshalMsg(nil)
require.NoError(b, err)
@ -719,7 +719,7 @@ func Benchmark_Redirect_Message(b *testing.B) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
val, err := testredirectionMsgs.MarshalMsg(nil)
require.NoError(b, err)
@ -750,7 +750,7 @@ func Benchmark_Redirect_OldInput(b *testing.B) {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
val, err := testredirectionMsgs.MarshalMsg(nil)
require.NoError(b, err)

View File

@ -6,18 +6,18 @@ package fiber
// Register defines all router handle interface generate by Route().
type Register interface {
All(handler Handler, middleware ...Handler) Register
Get(handler Handler, middleware ...Handler) Register
Head(handler Handler, middleware ...Handler) Register
Post(handler Handler, middleware ...Handler) Register
Put(handler Handler, middleware ...Handler) Register
Delete(handler Handler, middleware ...Handler) Register
Connect(handler Handler, middleware ...Handler) Register
Options(handler Handler, middleware ...Handler) Register
Trace(handler Handler, middleware ...Handler) Register
Patch(handler Handler, middleware ...Handler) Register
All(handler Handler, handlers ...Handler) Register
Get(handler Handler, handlers ...Handler) Register
Head(handler Handler, handlers ...Handler) Register
Post(handler Handler, handlers ...Handler) Register
Put(handler Handler, handlers ...Handler) Register
Delete(handler Handler, handlers ...Handler) Register
Connect(handler Handler, handlers ...Handler) Register
Options(handler Handler, handlers ...Handler) Register
Trace(handler Handler, handlers ...Handler) Register
Patch(handler Handler, handlers ...Handler) Register
Add(methods []string, handler Handler, middleware ...Handler) Register
Add(methods []string, handler Handler, handlers ...Handler) Register
Route(path string) Register
}
@ -45,68 +45,68 @@ type Registering struct {
// })
//
// This method will match all HTTP verbs: GET, POST, PUT, HEAD etc...
func (r *Registering) All(handler Handler, middleware ...Handler) Register {
r.app.register([]string{methodUse}, r.path, nil, handler, middleware...)
func (r *Registering) All(handler Handler, handlers ...Handler) Register {
r.app.register([]string{methodUse}, r.path, nil, append([]Handler{handler}, handlers...)...)
return r
}
// Get registers a route for GET methods that requests a representation
// of the specified resource. Requests using GET should only retrieve data.
func (r *Registering) Get(handler Handler, middleware ...Handler) Register {
r.app.Add([]string{MethodGet}, r.path, handler, middleware...)
func (r *Registering) Get(handler Handler, handlers ...Handler) Register {
r.app.Add([]string{MethodGet}, r.path, handler, handlers...)
return r
}
// Head registers a route for HEAD methods that asks for a response identical
// to that of a GET request, but without the response body.
func (r *Registering) Head(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodHead}, handler, middleware...)
func (r *Registering) Head(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodHead}, handler, handlers...)
}
// Post registers a route for POST methods that is used to submit an entity to the
// specified resource, often causing a change in state or side effects on the server.
func (r *Registering) Post(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodPost}, handler, middleware...)
func (r *Registering) Post(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodPost}, handler, handlers...)
}
// Put registers a route for PUT methods that replaces all current representations
// of the target resource with the request payload.
func (r *Registering) Put(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodPut}, handler, middleware...)
func (r *Registering) Put(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodPut}, handler, handlers...)
}
// Delete registers a route for DELETE methods that deletes the specified resource.
func (r *Registering) Delete(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodDelete}, handler, middleware...)
func (r *Registering) Delete(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodDelete}, handler, handlers...)
}
// Connect registers a route for CONNECT methods that establishes a tunnel to the
// server identified by the target resource.
func (r *Registering) Connect(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodConnect}, handler, middleware...)
func (r *Registering) Connect(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodConnect}, handler, handlers...)
}
// Options registers a route for OPTIONS methods that is used to describe the
// communication options for the target resource.
func (r *Registering) Options(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodOptions}, handler, middleware...)
func (r *Registering) Options(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodOptions}, handler, handlers...)
}
// Trace registers a route for TRACE methods that performs a message loop-back
// test along the r.Path to the target resource.
func (r *Registering) Trace(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodTrace}, handler, middleware...)
func (r *Registering) Trace(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodTrace}, handler, handlers...)
}
// Patch registers a route for PATCH methods that is used to apply partial
// modifications to a resource.
func (r *Registering) Patch(handler Handler, middleware ...Handler) Register {
return r.Add([]string{MethodPatch}, handler, middleware...)
func (r *Registering) Patch(handler Handler, handlers ...Handler) Register {
return r.Add([]string{MethodPatch}, handler, handlers...)
}
// Add allows you to specify multiple HTTP methods to register a route.
func (r *Registering) Add(methods []string, handler Handler, middleware ...Handler) Register {
r.app.register(methods, r.path, nil, handler, middleware...)
func (r *Registering) Add(methods []string, handler Handler, handlers ...Handler) Register {
r.app.register(methods, r.path, nil, append([]Handler{handler}, handlers...)...)
return r
}

159
req.go Normal file
View File

@ -0,0 +1,159 @@
package fiber
import (
"crypto/tls"
"mime/multipart"
)
//go:generate ifacemaker --file req.go --struct DefaultReq --iface Req --pkg fiber --output req_interface_gen.go --not-exported true --iface-comment "Req"
type DefaultReq struct {
ctx *DefaultCtx
}
func (r *DefaultReq) Accepts(offers ...string) string {
return r.ctx.Accepts(offers...)
}
func (r *DefaultReq) AcceptsCharsets(offers ...string) string {
return r.ctx.AcceptsCharsets(offers...)
}
func (r *DefaultReq) AcceptsEncodings(offers ...string) string {
return r.ctx.AcceptsEncodings(offers...)
}
func (r *DefaultReq) AcceptsLanguages(offers ...string) string {
return r.ctx.AcceptsLanguages(offers...)
}
func (r *DefaultReq) BaseURL() string {
return r.ctx.BaseURL()
}
func (r *DefaultReq) Body() []byte {
return r.ctx.Body()
}
func (r *DefaultReq) BodyRaw() []byte {
return r.ctx.BodyRaw()
}
func (r *DefaultReq) ClientHelloInfo() *tls.ClientHelloInfo {
return r.ctx.ClientHelloInfo()
}
func (r *DefaultReq) Cookies(key string, defaultValue ...string) string {
return r.ctx.Cookies(key, defaultValue...)
}
func (r *DefaultReq) FormFile(key string) (*multipart.FileHeader, error) {
return r.ctx.FormFile(key)
}
func (r *DefaultReq) FormValue(key string, defaultValue ...string) string {
return r.ctx.FormValue(key, defaultValue...)
}
func (r *DefaultReq) Fresh() bool {
return r.ctx.Fresh()
}
func (r *DefaultReq) Get(key string, defaultValue ...string) string {
return r.ctx.Get(key, defaultValue...)
}
func (r *DefaultReq) Host() string {
return r.ctx.Host()
}
func (r *DefaultReq) Hostname() string {
return r.ctx.Hostname()
}
func (r *DefaultReq) IP() string {
return r.ctx.IP()
}
func (r *DefaultReq) IPs() []string {
return r.ctx.IPs()
}
func (r *DefaultReq) Is(extension string) bool {
return r.ctx.Is(extension)
}
func (r *DefaultReq) IsFromLocal() bool {
return r.ctx.IsFromLocal()
}
func (r *DefaultReq) IsProxyTrusted() bool {
return r.ctx.IsProxyTrusted()
}
func (r *DefaultReq) Method(override ...string) string {
return r.ctx.Method(override...)
}
func (r *DefaultReq) MultipartForm() (*multipart.Form, error) {
return r.ctx.MultipartForm()
}
func (r *DefaultReq) OriginalURL() string {
return r.ctx.OriginalURL()
}
func (r *DefaultReq) Params(key string, defaultValue ...string) string {
return r.ctx.Params(key, defaultValue...)
}
func (r *DefaultReq) Path(override ...string) string {
return r.ctx.Path(override...)
}
func (r *DefaultReq) Port() string {
return r.ctx.Port()
}
func (r *DefaultReq) Protocol() string {
return r.ctx.Protocol()
}
func (r *DefaultReq) Queries() map[string]string {
return r.ctx.Queries()
}
func (r *DefaultReq) Query(key string, defaultValue ...string) string {
return r.ctx.Query(key, defaultValue...)
}
func (r *DefaultReq) Range(size int) (Range, error) {
return r.ctx.Range(size)
}
func (r *DefaultReq) Route() *Route {
return r.ctx.Route()
}
func (r *DefaultReq) SaveFile(fileheader *multipart.FileHeader, path string) error {
return r.ctx.SaveFile(fileheader, path)
}
func (r *DefaultReq) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
return r.ctx.SaveFileToStorage(fileheader, path, storage)
}
func (r *DefaultReq) Secure() bool {
return r.ctx.Secure()
}
func (r *DefaultReq) Stale() bool {
return r.ctx.Stale()
}
func (r *DefaultReq) Subdomains(offset ...int) []string {
return r.ctx.Subdomains(offset...)
}
func (r *DefaultReq) XHR() bool {
return r.ctx.XHR()
}

49
req_interface_gen.go Normal file
View File

@ -0,0 +1,49 @@
// Code generated by ifacemaker; DO NOT EDIT.
package fiber
import (
"crypto/tls"
"mime/multipart"
)
// Req
type Req interface {
Accepts(offers ...string) string
AcceptsCharsets(offers ...string) string
AcceptsEncodings(offers ...string) string
AcceptsLanguages(offers ...string) string
BaseURL() string
Body() []byte
BodyRaw() []byte
ClientHelloInfo() *tls.ClientHelloInfo
Cookies(key string, defaultValue ...string) string
FormFile(key string) (*multipart.FileHeader, error)
FormValue(key string, defaultValue ...string) string
Fresh() bool
Get(key string, defaultValue ...string) string
Host() string
Hostname() string
IP() string
IPs() []string
Is(extension string) bool
IsFromLocal() bool
IsProxyTrusted() bool
Method(override ...string) string
MultipartForm() (*multipart.Form, error)
OriginalURL() string
Params(key string, defaultValue ...string) string
Path(override ...string) string
Port() string
Protocol() string
Queries() map[string]string
Query(key string, defaultValue ...string) string
Range(size int) (Range, error)
Route() *Route
SaveFile(fileheader *multipart.FileHeader, path string) error
SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error
Secure() bool
Stale() bool
Subdomains(offset ...int) []string
XHR() bool
}

118
res.go Normal file
View File

@ -0,0 +1,118 @@
package fiber
import (
"bufio"
)
//go:generate ifacemaker --file res.go --struct DefaultRes --iface Res --pkg fiber --output res_interface_gen.go --not-exported true --iface-comment "Res"
type DefaultRes struct {
ctx *DefaultCtx
}
func (r *DefaultRes) Append(field string, values ...string) {
r.ctx.Append(field, values...)
}
func (r *DefaultRes) Attachment(filename ...string) {
r.ctx.Attachment(filename...)
}
func (r *DefaultRes) AutoFormat(body any) error {
return r.ctx.AutoFormat(body)
}
func (r *DefaultRes) CBOR(body any, ctype ...string) error {
return r.ctx.CBOR(body, ctype...)
}
func (r *DefaultRes) ClearCookie(key ...string) {
r.ctx.ClearCookie(key...)
}
func (r *DefaultRes) Cookie(cookie *Cookie) {
r.ctx.Cookie(cookie)
}
func (r *DefaultRes) Download(file string, filename ...string) error {
return r.ctx.Download(file, filename...)
}
func (r *DefaultRes) Format(handlers ...ResFmt) error {
return r.ctx.Format(handlers...)
}
func (r *DefaultRes) Get(key string, defaultValue ...string) string {
return r.ctx.GetRespHeader(key, defaultValue...)
}
func (r *DefaultRes) JSON(body any, ctype ...string) error {
return r.ctx.JSON(body, ctype...)
}
func (r *DefaultRes) JSONP(data any, callback ...string) error {
return r.ctx.JSONP(data, callback...)
}
func (r *DefaultRes) Links(link ...string) {
r.ctx.Links(link...)
}
func (r *DefaultRes) Location(path string) {
r.ctx.Location(path)
}
func (r *DefaultRes) Render(name string, bind any, layouts ...string) error {
return r.ctx.Render(name, bind, layouts...)
}
func (r *DefaultRes) Send(body []byte) error {
return r.ctx.Send(body)
}
func (r *DefaultRes) SendFile(file string, config ...SendFile) error {
return r.ctx.SendFile(file, config...)
}
func (r *DefaultRes) SendStatus(status int) error {
return r.ctx.SendStatus(status)
}
func (r *DefaultRes) SendString(body string) error {
return r.ctx.SendString(body)
}
func (r *DefaultRes) SendStreamWriter(streamWriter func(*bufio.Writer)) error {
return r.ctx.SendStreamWriter(streamWriter)
}
func (r *DefaultRes) Set(key, val string) {
r.ctx.Set(key, val)
}
func (r *DefaultRes) Status(status int) Ctx {
return r.ctx.Status(status)
}
func (r *DefaultRes) Type(extension string, charset ...string) Ctx {
return r.ctx.Type(extension, charset...)
}
func (r *DefaultRes) Vary(fields ...string) {
r.ctx.Vary(fields...)
}
func (r *DefaultRes) Write(p []byte) (int, error) {
return r.ctx.Write(p)
}
func (r *DefaultRes) Writef(f string, a ...any) (int, error) {
return r.ctx.Writef(f, a...)
}
func (r *DefaultRes) WriteString(s string) (int, error) {
return r.ctx.WriteString(s)
}
func (r *DefaultRes) XML(data any) error {
return r.ctx.XML(data)
}

38
res_interface_gen.go Normal file
View File

@ -0,0 +1,38 @@
// Code generated by ifacemaker; DO NOT EDIT.
package fiber
import (
"bufio"
)
// Res
type Res interface {
Append(field string, values ...string)
Attachment(filename ...string)
AutoFormat(body any) error
CBOR(body any, ctype ...string) error
ClearCookie(key ...string)
Cookie(cookie *Cookie)
Download(file string, filename ...string) error
Format(handlers ...ResFmt) error
Get(key string, defaultValue ...string) string
JSON(body any, ctype ...string) error
JSONP(data any, callback ...string) error
Links(link ...string)
Location(path string)
Render(name string, bind any, layouts ...string) error
Send(body []byte) error
SendFile(file string, config ...SendFile) error
SendStatus(status int) error
SendString(body string) error
SendStreamWriter(streamWriter func(*bufio.Writer)) error
Set(key, val string)
Status(status int) Ctx
Type(extension string, charset ...string) Ctx
Vary(fields ...string)
Write(p []byte) (int, error)
Writef(f string, a ...any) (int, error)
WriteString(s string) (int, error)
XML(data any) error
}

View File

@ -20,18 +20,18 @@ import (
type Router interface {
Use(args ...any) Router
Get(path string, handler Handler, middleware ...Handler) Router
Head(path string, handler Handler, middleware ...Handler) Router
Post(path string, handler Handler, middleware ...Handler) Router
Put(path string, handler Handler, middleware ...Handler) Router
Delete(path string, handler Handler, middleware ...Handler) Router
Connect(path string, handler Handler, middleware ...Handler) Router
Options(path string, handler Handler, middleware ...Handler) Router
Trace(path string, handler Handler, middleware ...Handler) Router
Patch(path string, handler Handler, middleware ...Handler) Router
Get(path string, handler Handler, handlers ...Handler) Router
Head(path string, handler Handler, handlers ...Handler) Router
Post(path string, handler Handler, handlers ...Handler) Router
Put(path string, handler Handler, handlers ...Handler) Router
Delete(path string, handler Handler, handlers ...Handler) Router
Connect(path string, handler Handler, handlers ...Handler) Router
Options(path string, handler Handler, handlers ...Handler) Router
Trace(path string, handler Handler, handlers ...Handler) Router
Patch(path string, handler Handler, handlers ...Handler) Router
Add(methods []string, path string, handler Handler, middleware ...Handler) Router
All(path string, handler Handler, middleware ...Handler) Router
Add(methods []string, path string, handler Handler, handlers ...Handler) Router
All(path string, handler Handler, handlers ...Handler) Router
Group(prefix string, handlers ...Handler) Router
@ -108,11 +108,11 @@ func (r *Route) match(detectionPath, path string, params *[maxParams]string) boo
return false
}
func (app *App) nextCustom(c CustomCtx) (bool, error) { //nolint: unparam // bool param might be useful for testing
func (app *App) nextCustom(c CustomCtx) (bool, error) { //nolint:unparam // bool param might be useful for testing
// Get stack length
tree, ok := app.treeStack[c.getMethodINT()][c.getTreePath()]
tree, ok := app.treeStack[c.getMethodInt()][c.getTreePathHash()]
if !ok {
tree = app.treeStack[c.getMethodINT()][""]
tree = app.treeStack[c.getMethodInt()][0]
}
lenr := len(tree) - 1
@ -158,9 +158,9 @@ func (app *App) nextCustom(c CustomCtx) (bool, error) { //nolint: unparam // boo
func (app *App) next(c *DefaultCtx) (bool, error) {
// Get stack length
tree, ok := app.treeStack[c.methodINT][c.treePath]
tree, ok := app.treeStack[c.methodInt][c.treePathHash]
if !ok {
tree = app.treeStack[c.methodINT][""]
tree = app.treeStack[c.methodInt][0]
}
lenTree := len(tree) - 1
@ -180,7 +180,7 @@ func (app *App) next(c *DefaultCtx) (bool, error) {
}
// Check if it matches the request path
match = route.match(c.detectionPath, c.path, &c.values)
match = route.match(utils.UnsafeString(c.detectionPath), utils.UnsafeString(c.path), &c.values)
if !match {
// No match, next route
continue
@ -202,7 +202,7 @@ func (app *App) next(c *DefaultCtx) (bool, error) {
}
// If c.Next() does not match, return 404
err := NewError(StatusNotFound, "Cannot "+c.method+" "+html.EscapeString(c.pathOriginal))
err := NewError(StatusNotFound, "Cannot "+c.Method()+" "+html.EscapeString(c.pathOriginal))
if !c.matched && app.methodExist(c) {
// If no match, scan stack again if other methods match the request
// Moved from app.handler because middleware may break the route chain
@ -221,7 +221,7 @@ func (app *App) defaultRequestHandler(rctx *fasthttp.RequestCtx) {
defer app.ReleaseCtx(ctx)
// Check if the HTTP method is valid
if ctx.methodINT == -1 {
if ctx.methodInt == -1 {
_ = ctx.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil
return
}
@ -318,10 +318,16 @@ func (*App) copyRoute(route *Route) *Route {
}
}
func (app *App) register(methods []string, pathRaw string, group *Group, handler Handler, middleware ...Handler) {
handlers := middleware
if handler != nil {
handlers = append(handlers, handler)
func (app *App) register(methods []string, pathRaw string, group *Group, handlers ...Handler) {
// A regular route requires at least one ctx handler
if len(handlers) == 0 && group == nil {
panic(fmt.Sprintf("missing handler/middleware in route: %s\n", pathRaw))
}
// No nil handlers allowed
for _, h := range handlers {
if nil == h {
panic(fmt.Sprintf("nil handler in route: %s\n", pathRaw))
}
}
// Precompute path normalization ONCE
@ -343,17 +349,14 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler
parsedRaw := parseRoute(pathRaw, app.customConstraints...)
parsedPretty := parseRoute(pathPretty, app.customConstraints...)
isMount := group != nil && group.app != app
for _, method := range methods {
method = utils.ToUpper(method)
if method != methodUse && app.methodInt(method) == -1 {
panic(fmt.Sprintf("add: invalid http method %s\n", method))
}
isMount := group != nil && group.app != app
if len(handlers) == 0 && !isMount {
panic(fmt.Sprintf("missing handler/middleware in route: %s\n", pathRaw))
}
isUse := method == methodUse
isStar := pathClean == "/*"
isRoot := pathClean == "/"
@ -451,30 +454,28 @@ func (app *App) buildTree() *App {
// loop all the methods and stacks and create the prefix tree
for m := range app.config.RequestMethods {
tsMap := make(map[string][]*Route)
tsMap := make(map[int][]*Route)
for _, route := range app.stack[m] {
treePath := ""
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
treePath = route.routeParser.segs[0].Const[:3]
treePathHash := 0
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= maxDetectionPaths {
treePathHash = int(route.routeParser.segs[0].Const[0])<<16 |
int(route.routeParser.segs[0].Const[1])<<8 |
int(route.routeParser.segs[0].Const[2])
}
// create tree stack
tsMap[treePath] = append(tsMap[treePath], route)
tsMap[treePathHash] = append(tsMap[treePathHash], route)
}
app.treeStack[m] = tsMap
}
// loop the methods and tree stacks and add global stack and sort everything
for m := range app.config.RequestMethods {
tsMap := app.treeStack[m]
for treePart := range tsMap {
if treePart != "" {
if treePart != 0 {
// merge global tree routes in current tree stack
tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...))
tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[0]...))
}
// sort tree slices with the positions
slc := tsMap[treePart]
sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos })
}
app.treeStack[m] = tsMap
}
app.routesRefreshed = false

View File

@ -7,7 +7,6 @@ package fiber
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
@ -31,6 +30,39 @@ func init() {
}
}
func Test_Route_Handler_Order(t *testing.T) {
t.Parallel()
app := New()
var order []int
handler1 := func(c Ctx) error {
order = append(order, 1)
return c.Next()
}
handler2 := func(c Ctx) error {
order = append(order, 2)
return c.Next()
}
handler3 := func(c Ctx) error {
order = append(order, 3)
return c.Next()
}
app.Get("/test", handler1, handler2, handler3, func(c Ctx) error {
order = append(order, 4)
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, 200, resp.StatusCode, "Status code")
expectedOrder := []int{1, 2, 3, 4}
require.Equal(t, expectedOrder, order, "Handler order")
}
func Test_Route_Match_SameLength(t *testing.T) {
t.Parallel()
@ -294,12 +326,22 @@ func Test_Router_Register_Missing_Handler(t *testing.T) {
t.Parallel()
app := New()
defer func() {
if err := recover(); err != nil {
require.Equal(t, "missing handler/middleware in route: /doe\n", fmt.Sprintf("%v", err))
}
}()
app.register([]string{"USE"}, "/doe", nil, nil)
t.Run("No Handler", func(t *testing.T) {
t.Parallel()
require.PanicsWithValue(t, "missing handler/middleware in route: /doe\n", func() {
app.register([]string{"USE"}, "/doe", nil)
})
})
t.Run("Nil Handler", func(t *testing.T) {
t.Parallel()
require.PanicsWithValue(t, "nil handler in route: /doe\n", func() {
app.register([]string{"USE"}, "/doe", nil, nil)
})
})
}
func Test_Ensure_Router_Interface_Implementation(t *testing.T) {
@ -558,7 +600,7 @@ func Benchmark_Router_Next(b *testing.B) {
var res bool
var err error
c := app.AcquireCtx(request).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
c := app.AcquireCtx(request).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
b.ResetTimer()
for n := 0; n < b.N; n++ {
@ -614,6 +656,50 @@ func Benchmark_Router_Next_Default_Parallel(b *testing.B) {
})
}
// go test -v ./... -run=^$ -bench=Benchmark_Router_Next_Default_Immutable -benchmem -count=4
func Benchmark_Router_Next_Default_Immutable(b *testing.B) {
app := New(Config{Immutable: true})
app.Get("/", func(_ Ctx) error {
return nil
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(MethodGet)
fctx.Request.SetRequestURI("/")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
// go test -benchmem -run=^$ -bench ^Benchmark_Router_Next_Default_Parallel_Immutable$ github.com/gofiber/fiber/v3 -count=1
func Benchmark_Router_Next_Default_Parallel_Immutable(b *testing.B) {
app := New(Config{Immutable: true})
app.Get("/", func(_ Ctx) error {
return nil
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(MethodGet)
fctx.Request.SetRequestURI("/")
for pb.Next() {
h(fctx)
}
})
}
// go test -v ./... -run=^$ -bench=Benchmark_Route_Match -benchmem -count=4
func Benchmark_Route_Match(b *testing.B) {
var match bool
@ -783,7 +869,7 @@ func Benchmark_Router_Github_API(b *testing.B) {
for n := 0; n < b.N; n++ {
c.URI().SetPath(routesFixture.TestRoutes[i].Path)
ctx := app.AcquireCtx(c).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed
ctx := app.AcquireCtx(c).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
match, err = app.next(ctx)
app.ReleaseCtx(ctx)

322
state.go Normal file
View File

@ -0,0 +1,322 @@
package fiber
import (
"sync"
)
// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies.
// It's a thread-safe implementation of a map[string]any, using sync.Map.
type State struct {
dependencies sync.Map
}
// NewState creates a new instance of State.
func newState() *State {
return &State{
dependencies: sync.Map{},
}
}
// Set sets a key-value pair in the State.
func (s *State) Set(key string, value any) {
s.dependencies.Store(key, value)
}
// Get retrieves a value from the State.
func (s *State) Get(key string) (any, bool) {
return s.dependencies.Load(key)
}
// MustGet retrieves a value from the State and panics if the key is not found.
func (s *State) MustGet(key string) any {
if dep, ok := s.Get(key); ok {
return dep
}
panic("state: dependency not found!")
}
// Has checks if a key is present in the State.
// It returns a boolean indicating if the key is present.
func (s *State) Has(key string) bool {
_, ok := s.Get(key)
return ok
}
// Delete removes a key-value pair from the State.
func (s *State) Delete(key string) {
s.dependencies.Delete(key)
}
// Reset resets the State by removing all keys.
func (s *State) Reset() {
s.dependencies.Clear()
}
// Keys returns a slice containing all keys present in the State.
func (s *State) Keys() []string {
keys := make([]string, 0)
s.dependencies.Range(func(key, _ any) bool {
keyStr, ok := key.(string)
if !ok {
return false
}
keys = append(keys, keyStr)
return true
})
return keys
}
// Len returns the number of keys in the State.
func (s *State) Len() int {
length := 0
s.dependencies.Range(func(_, _ any) bool {
length++
return true
})
return length
}
// GetState retrieves a value from the State and casts it to the desired type.
// It returns the casted value and a boolean indicating if the cast was successful.
func GetState[T any](s *State, key string) (T, bool) {
dep, ok := s.Get(key)
if ok {
depT, okCast := dep.(T)
return depT, okCast
}
var zeroVal T
return zeroVal, false
}
// MustGetState retrieves a value from the State and casts it to the desired type.
// It panics if the key is not found or if the type assertion fails.
func MustGetState[T any](s *State, key string) T {
dep, ok := GetState[T](s, key)
if !ok {
panic("state: dependency not found!")
}
return dep
}
// GetStateWithDefault retrieves a value from the State,
// casting it to the desired type. If the key is not present,
// it returns the provided default value.
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T {
dep, ok := GetState[T](s, key)
if !ok {
return defaultVal
}
return dep
}
// GetString retrieves a string value from the State.
// It returns the string and a boolean indicating successful type assertion.
func (s *State) GetString(key string) (string, bool) {
dep, ok := s.Get(key)
if ok {
depString, okCast := dep.(string)
return depString, okCast
}
return "", false
}
// GetInt retrieves an integer value from the State.
// It returns the int and a boolean indicating successful type assertion.
func (s *State) GetInt(key string) (int, bool) {
dep, ok := s.Get(key)
if ok {
depInt, okCast := dep.(int)
return depInt, okCast
}
return 0, false
}
// GetBool retrieves a boolean value from the State.
// It returns the bool and a boolean indicating successful type assertion.
func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here
dep, ok := s.Get(key)
if ok {
depBool, okCast := dep.(bool)
return depBool, okCast
}
return false, false
}
// GetFloat64 retrieves a float64 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetFloat64(key string) (float64, bool) {
dep, ok := s.Get(key)
if ok {
depFloat64, okCast := dep.(float64)
return depFloat64, okCast
}
return 0, false
}
// GetUint retrieves a uint value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetUint(key string) (uint, bool) {
dep, ok := s.Get(key)
if ok {
if depUint, okCast := dep.(uint); okCast {
return depUint, true
}
}
return 0, false
}
// GetInt8 retrieves an int8 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetInt8(key string) (int8, bool) {
dep, ok := s.Get(key)
if ok {
if depInt8, okCast := dep.(int8); okCast {
return depInt8, true
}
}
return 0, false
}
// GetInt16 retrieves an int16 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetInt16(key string) (int16, bool) {
dep, ok := s.Get(key)
if ok {
if depInt16, okCast := dep.(int16); okCast {
return depInt16, true
}
}
return 0, false
}
// GetInt32 retrieves an int32 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetInt32(key string) (int32, bool) {
dep, ok := s.Get(key)
if ok {
if depInt32, okCast := dep.(int32); okCast {
return depInt32, true
}
}
return 0, false
}
// GetInt64 retrieves an int64 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetInt64(key string) (int64, bool) {
dep, ok := s.Get(key)
if ok {
if depInt64, okCast := dep.(int64); okCast {
return depInt64, true
}
}
return 0, false
}
// GetUint8 retrieves a uint8 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetUint8(key string) (uint8, bool) {
dep, ok := s.Get(key)
if ok {
if depUint8, okCast := dep.(uint8); okCast {
return depUint8, true
}
}
return 0, false
}
// GetUint16 retrieves a uint16 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetUint16(key string) (uint16, bool) {
dep, ok := s.Get(key)
if ok {
if depUint16, okCast := dep.(uint16); okCast {
return depUint16, true
}
}
return 0, false
}
// GetUint32 retrieves a uint32 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetUint32(key string) (uint32, bool) {
dep, ok := s.Get(key)
if ok {
if depUint32, okCast := dep.(uint32); okCast {
return depUint32, true
}
}
return 0, false
}
// GetUint64 retrieves a uint64 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetUint64(key string) (uint64, bool) {
dep, ok := s.Get(key)
if ok {
if depUint64, okCast := dep.(uint64); okCast {
return depUint64, true
}
}
return 0, false
}
// GetUintptr retrieves a uintptr value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetUintptr(key string) (uintptr, bool) {
dep, ok := s.Get(key)
if ok {
if depUintptr, okCast := dep.(uintptr); okCast {
return depUintptr, true
}
}
return 0, false
}
// GetFloat32 retrieves a float32 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetFloat32(key string) (float32, bool) {
dep, ok := s.Get(key)
if ok {
if depFloat32, okCast := dep.(float32); okCast {
return depFloat32, true
}
}
return 0, false
}
// GetComplex64 retrieves a complex64 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetComplex64(key string) (complex64, bool) {
dep, ok := s.Get(key)
if ok {
if depComplex64, okCast := dep.(complex64); okCast {
return depComplex64, true
}
}
return 0, false
}
// GetComplex128 retrieves a complex128 value from the State.
// It returns the float64 and a boolean indicating successful type assertion.
func (s *State) GetComplex128(key string) (complex128, bool) {
dep, ok := s.Get(key)
if ok {
if depComplex128, okCast := dep.(complex128); okCast {
return depComplex128, true
}
}
return 0, false
}

981
state_test.go Normal file
View File

@ -0,0 +1,981 @@
package fiber
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestState_SetAndGet_WithApp(t *testing.T) {
t.Parallel()
// Create app
app := New()
// test setting and getting a value
app.State().Set("foo", "bar")
val, ok := app.State().Get("foo")
require.True(t, ok)
require.Equal(t, "bar", val)
// test key not found
_, ok = app.State().Get("unknown")
require.False(t, ok)
}
func TestState_SetAndGet(t *testing.T) {
t.Parallel()
st := newState()
// test setting and getting a value
st.Set("foo", "bar")
val, ok := st.Get("foo")
require.True(t, ok)
require.Equal(t, "bar", val)
// test key not found
_, ok = st.Get("unknown")
require.False(t, ok)
}
func TestState_GetString(t *testing.T) {
t.Parallel()
st := newState()
st.Set("str", "hello")
s, ok := st.GetString("str")
require.True(t, ok)
require.Equal(t, "hello", s)
// wrong type should return false
st.Set("num", 123)
s, ok = st.GetString("num")
require.False(t, ok)
require.Equal(t, "", s)
// missing key should return false
s, ok = st.GetString("missing")
require.False(t, ok)
require.Equal(t, "", s)
}
func TestState_GetInt(t *testing.T) {
t.Parallel()
st := newState()
st.Set("num", 456)
i, ok := st.GetInt("num")
require.True(t, ok)
require.Equal(t, 456, i)
// wrong type should return zero value
st.Set("str", "abc")
i, ok = st.GetInt("str")
require.False(t, ok)
require.Equal(t, 0, i)
// missing key should return zero value
i, ok = st.GetInt("missing")
require.False(t, ok)
require.Equal(t, 0, i)
}
func TestState_GetBool(t *testing.T) {
t.Parallel()
st := newState()
st.Set("flag", true)
b, ok := st.GetBool("flag")
require.True(t, ok)
require.True(t, b)
// wrong type
st.Set("num", 1)
b, ok = st.GetBool("num")
require.False(t, ok)
require.False(t, b)
// missing key should return false
b, ok = st.GetBool("missing")
require.False(t, ok)
require.False(t, b)
}
func TestState_GetFloat64(t *testing.T) {
t.Parallel()
st := newState()
st.Set("pi", 3.14)
f, ok := st.GetFloat64("pi")
require.True(t, ok)
require.InDelta(t, 3.14, f, 0.0001)
// wrong type should return zero value
st.Set("int", 10)
f, ok = st.GetFloat64("int")
require.False(t, ok)
require.InDelta(t, 0.0, f, 0.0001)
// missing key should return zero value
f, ok = st.GetFloat64("missing")
require.False(t, ok)
require.InDelta(t, 0.0, f, 0.0001)
}
func TestState_GetUint(t *testing.T) {
t.Parallel()
st := newState()
st.Set("uint", uint(100))
u, ok := st.GetUint("uint")
require.True(t, ok)
require.Equal(t, uint(100), u)
st.Set("wrong", "not uint")
u, ok = st.GetUint("wrong")
require.False(t, ok)
require.Equal(t, uint(0), u)
u, ok = st.GetUint("missing")
require.False(t, ok)
require.Equal(t, uint(0), u)
}
func TestState_GetInt8(t *testing.T) {
t.Parallel()
st := newState()
st.Set("int8", int8(10))
i, ok := st.GetInt8("int8")
require.True(t, ok)
require.Equal(t, int8(10), i)
st.Set("wrong", "not int8")
i, ok = st.GetInt8("wrong")
require.False(t, ok)
require.Equal(t, int8(0), i)
i, ok = st.GetInt8("missing")
require.False(t, ok)
require.Equal(t, int8(0), i)
}
func TestState_GetInt16(t *testing.T) {
t.Parallel()
st := newState()
st.Set("int16", int16(200))
i, ok := st.GetInt16("int16")
require.True(t, ok)
require.Equal(t, int16(200), i)
st.Set("wrong", "not int16")
i, ok = st.GetInt16("wrong")
require.False(t, ok)
require.Equal(t, int16(0), i)
i, ok = st.GetInt16("missing")
require.False(t, ok)
require.Equal(t, int16(0), i)
}
func TestState_GetInt32(t *testing.T) {
t.Parallel()
st := newState()
st.Set("int32", int32(3000))
i, ok := st.GetInt32("int32")
require.True(t, ok)
require.Equal(t, int32(3000), i)
st.Set("wrong", "not int32")
i, ok = st.GetInt32("wrong")
require.False(t, ok)
require.Equal(t, int32(0), i)
i, ok = st.GetInt32("missing")
require.False(t, ok)
require.Equal(t, int32(0), i)
}
func TestState_GetInt64(t *testing.T) {
t.Parallel()
st := newState()
st.Set("int64", int64(4000))
i, ok := st.GetInt64("int64")
require.True(t, ok)
require.Equal(t, int64(4000), i)
st.Set("wrong", "not int64")
i, ok = st.GetInt64("wrong")
require.False(t, ok)
require.Equal(t, int64(0), i)
i, ok = st.GetInt64("missing")
require.False(t, ok)
require.Equal(t, int64(0), i)
}
func TestState_GetUint8(t *testing.T) {
t.Parallel()
st := newState()
st.Set("uint8", uint8(20))
u, ok := st.GetUint8("uint8")
require.True(t, ok)
require.Equal(t, uint8(20), u)
st.Set("wrong", "not uint8")
u, ok = st.GetUint8("wrong")
require.False(t, ok)
require.Equal(t, uint8(0), u)
u, ok = st.GetUint8("missing")
require.False(t, ok)
require.Equal(t, uint8(0), u)
}
func TestState_GetUint16(t *testing.T) {
t.Parallel()
st := newState()
st.Set("uint16", uint16(300))
u, ok := st.GetUint16("uint16")
require.True(t, ok)
require.Equal(t, uint16(300), u)
st.Set("wrong", "not uint16")
u, ok = st.GetUint16("wrong")
require.False(t, ok)
require.Equal(t, uint16(0), u)
u, ok = st.GetUint16("missing")
require.False(t, ok)
require.Equal(t, uint16(0), u)
}
func TestState_GetUint32(t *testing.T) {
t.Parallel()
st := newState()
st.Set("uint32", uint32(400000))
u, ok := st.GetUint32("uint32")
require.True(t, ok)
require.Equal(t, uint32(400000), u)
st.Set("wrong", "not uint32")
u, ok = st.GetUint32("wrong")
require.False(t, ok)
require.Equal(t, uint32(0), u)
u, ok = st.GetUint32("missing")
require.False(t, ok)
require.Equal(t, uint32(0), u)
}
func TestState_GetUint64(t *testing.T) {
t.Parallel()
st := newState()
st.Set("uint64", uint64(5000000))
u, ok := st.GetUint64("uint64")
require.True(t, ok)
require.Equal(t, uint64(5000000), u)
st.Set("wrong", "not uint64")
u, ok = st.GetUint64("wrong")
require.False(t, ok)
require.Equal(t, uint64(0), u)
u, ok = st.GetUint64("missing")
require.False(t, ok)
require.Equal(t, uint64(0), u)
}
func TestState_GetUintptr(t *testing.T) {
t.Parallel()
st := newState()
var ptr uintptr = 12345
st.Set("uintptr", ptr)
u, ok := st.GetUintptr("uintptr")
require.True(t, ok)
require.Equal(t, ptr, u)
st.Set("wrong", "not uintptr")
u, ok = st.GetUintptr("wrong")
require.False(t, ok)
require.Equal(t, uintptr(0), u)
u, ok = st.GetUintptr("missing")
require.False(t, ok)
require.Equal(t, uintptr(0), u)
}
func TestState_GetFloat32(t *testing.T) {
t.Parallel()
st := newState()
st.Set("float32", float32(3.14))
f, ok := st.GetFloat32("float32")
require.True(t, ok)
require.InDelta(t, float32(3.14), f, 0.0001)
st.Set("wrong", "not float32")
f, ok = st.GetFloat32("wrong")
require.False(t, ok)
require.InDelta(t, float32(0), f, 0.0001)
f, ok = st.GetFloat32("missing")
require.False(t, ok)
require.InDelta(t, float32(0), f, 0.0001)
}
func TestState_GetComplex64(t *testing.T) {
t.Parallel()
st := newState()
var c complex64 = complex(2, 3)
st.Set("complex64", c)
cRes, ok := st.GetComplex64("complex64")
require.True(t, ok)
require.Equal(t, c, cRes)
st.Set("wrong", "not complex64")
cRes, ok = st.GetComplex64("wrong")
require.False(t, ok)
require.Equal(t, complex64(0), cRes)
cRes, ok = st.GetComplex64("missing")
require.False(t, ok)
require.Equal(t, complex64(0), cRes)
}
func TestState_GetComplex128(t *testing.T) {
t.Parallel()
st := newState()
c := complex(4, 5)
st.Set("complex128", c)
cRes, ok := st.GetComplex128("complex128")
require.True(t, ok)
require.Equal(t, c, cRes)
st.Set("wrong", "not complex128")
cRes, ok = st.GetComplex128("wrong")
require.False(t, ok)
require.Equal(t, complex128(0), cRes)
cRes, ok = st.GetComplex128("missing")
require.False(t, ok)
require.Equal(t, complex128(0), cRes)
}
func TestState_MustGet(t *testing.T) {
t.Parallel()
st := newState()
st.Set("exists", "value")
val := st.MustGet("exists")
require.Equal(t, "value", val)
// must-get on missing key should panic
require.Panics(t, func() {
_ = st.MustGet("missing")
})
}
func TestState_Has(t *testing.T) {
t.Parallel()
st := newState()
st.Set("key", "value")
require.True(t, st.Has("key"))
}
func TestState_Delete(t *testing.T) {
t.Parallel()
st := newState()
st.Set("key", "value")
st.Delete("key")
_, ok := st.Get("key")
require.False(t, ok)
}
func TestState_Reset(t *testing.T) {
t.Parallel()
st := newState()
st.Set("a", 1)
st.Set("b", 2)
st.Reset()
require.Equal(t, 0, st.Len())
require.Empty(t, st.Keys())
}
func TestState_Keys(t *testing.T) {
t.Parallel()
st := newState()
keys := []string{"one", "two", "three"}
for _, k := range keys {
st.Set(k, k)
}
returnedKeys := st.Keys()
require.ElementsMatch(t, keys, returnedKeys)
}
func TestState_Len(t *testing.T) {
t.Parallel()
st := newState()
require.Equal(t, 0, st.Len())
st.Set("a", "a")
require.Equal(t, 1, st.Len())
st.Set("b", "b")
require.Equal(t, 2, st.Len())
st.Delete("a")
require.Equal(t, 1, st.Len())
}
type testCase[T any] struct { //nolint:govet // It does not really matter for test
name string
key string
value any
expected T
ok bool
}
func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) {
t.Helper()
st := newState()
for _, tc := range tests {
st.Set(tc.key, tc.value)
got, ok := getter(st, tc.key)
require.Equal(t, tc.ok, ok, tc.name)
require.Equal(t, tc.expected, got, tc.name)
}
}
func TestState_GetGeneric(t *testing.T) {
t.Parallel()
runGenericTest[int](t, GetState[int], []testCase[int]{
{"int correct conversion", "num", 42, 42, true},
{"int wrong conversion from string", "str", "abc", 0, false},
})
runGenericTest[string](t, GetState[string], []testCase[string]{
{"string correct conversion", "strVal", "hello", "hello", true},
{"string wrong conversion from int", "intVal", 100, "", false},
})
runGenericTest[bool](t, GetState[bool], []testCase[bool]{
{"bool correct conversion", "flag", true, true, true},
{"bool wrong conversion from int", "intFlag", 1, false, false},
})
runGenericTest[float64](t, GetState[float64], []testCase[float64]{
{"float64 correct conversion", "pi", 3.14, 3.14, true},
{"float64 wrong conversion from int", "intVal", 10, 0.0, false},
})
}
func Test_MustGetStateGeneric(t *testing.T) {
t.Parallel()
st := newState()
st.Set("flag", true)
flag := MustGetState[bool](st, "flag")
require.True(t, flag)
// mismatched type should panic
require.Panics(t, func() {
_ = MustGetState[string](st, "flag")
})
// missing key should also panic
require.Panics(t, func() {
_ = MustGetState[string](st, "missing")
})
}
func Test_GetStateWithDefault(t *testing.T) {
t.Parallel()
st := newState()
st.Set("flag", true)
flag := GetStateWithDefault(st, "flag", false)
require.True(t, flag)
// mismatched type should return the default value
str := GetStateWithDefault(st, "flag", "default")
require.Equal(t, "default", str)
// missing key should return the default value
flag = GetStateWithDefault(st, "missing", false)
require.False(t, flag)
}
func BenchmarkState_Set(b *testing.B) {
b.ReportAllocs()
st := newState()
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
}
func BenchmarkState_Get(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.Get(key)
}
}
func BenchmarkState_GetString(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, strconv.Itoa(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetString(key)
}
}
func BenchmarkState_GetInt(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetInt(key)
}
}
func BenchmarkState_GetBool(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i%2 == 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetBool(key)
}
}
func BenchmarkState_GetFloat64(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, float64(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetFloat64(key)
}
}
func BenchmarkState_MustGet(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.MustGet(key)
}
}
func BenchmarkState_GetStateGeneric(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
GetState[int](st, key)
}
}
func BenchmarkState_MustGetStateGeneric(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
MustGetState[int](st, key)
}
}
func BenchmarkState_GetStateWithDefault(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
GetStateWithDefault[int](st, key, 0)
}
}
func BenchmarkState_Has(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
st.Set("key"+strconv.Itoa(i), i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
st.Has("key" + strconv.Itoa(i%n))
}
}
func BenchmarkState_Delete(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
st := newState()
st.Set("a", 1)
st.Delete("a")
}
}
func BenchmarkState_Reset(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
st := newState()
// add a fixed number of keys before clearing
for j := 0; j < 100; j++ {
st.Set("key"+strconv.Itoa(j), j)
}
st.Reset()
}
}
func BenchmarkState_Keys(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
for i := 0; i < n; i++ {
st.Set("key"+strconv.Itoa(i), i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = st.Keys()
}
}
func BenchmarkState_Len(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
for i := 0; i < n; i++ {
st.Set("key"+strconv.Itoa(i), i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = st.Len()
}
}
func BenchmarkState_GetUint(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with uint values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, uint(i)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetUint(key)
}
}
func BenchmarkState_GetInt8(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with int8 values (using modulo to stay in range).
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, int8(i%128)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetInt8(key)
}
}
func BenchmarkState_GetInt16(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with int16 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, int16(i)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetInt16(key)
}
}
func BenchmarkState_GetInt32(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with int32 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, int32(i)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetInt32(key)
}
}
func BenchmarkState_GetInt64(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with int64 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, int64(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetInt64(key)
}
}
func BenchmarkState_GetUint8(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with uint8 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, uint8(i%256)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetUint8(key)
}
}
func BenchmarkState_GetUint16(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with uint16 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, uint16(i)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetUint16(key)
}
}
func BenchmarkState_GetUint32(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with uint32 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, uint32(i)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetUint32(key)
}
}
func BenchmarkState_GetUint64(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with uint64 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, uint64(i)) //nolint:gosec // This is a test
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetUint64(key)
}
}
func BenchmarkState_GetUintptr(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with uintptr values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, uintptr(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetUintptr(key)
}
}
func BenchmarkState_GetFloat32(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with float32 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, float32(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetFloat32(key)
}
}
func BenchmarkState_GetComplex64(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with complex64 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
// Create a complex64 value with both real and imaginary parts.
st.Set(key, complex(float32(i), float32(i)))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetComplex64(key)
}
}
func BenchmarkState_GetComplex128(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// Pre-populate the state with complex128 values.
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
// Create a complex128 value with both real and imaginary parts.
st.Set(key, complex(float64(i), float64(i)))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetComplex128(key)
}
}