From 75281bd874f7f179c18adf3f55175b84409c8fb8 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sun, 30 Mar 2025 05:46:52 -0400 Subject: [PATCH 1/7] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20Simplify=20HealthCh?= =?UTF-8?q?eck=20middleware=20(#3380)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Simplify middleware * Rename default endpoints --- docs/middleware/healthcheck.md | 32 ++++++++++------- docs/whats_new.md | 8 ++--- middleware/healthcheck/config.go | 12 +++---- middleware/healthcheck/healthcheck.go | 7 ++-- middleware/healthcheck/healthcheck_test.go | 42 +++++++++++----------- 5 files changed, 52 insertions(+), 49 deletions(-) diff --git a/docs/middleware/healthcheck.md b/docs/middleware/healthcheck.md index 2837c550..122f5768 100644 --- a/docs/middleware/healthcheck.md +++ b/docs/middleware/healthcheck.md @@ -27,7 +27,7 @@ Liveness, readiness and startup probes middleware for [Fiber](https://github.com ## Signatures ```go -func NewHealthChecker(config Config) fiber.Handler +func New(config Config) fiber.Handler ``` ## Examples @@ -41,38 +41,44 @@ import( ) ``` -After you initiate your [Fiber](https://github.com/gofiber/fiber) app, you can use the following possibilities: +After you initiate your [Fiber](https://github.com/gofiber/fiber) app, you can use the following options: ```go // Provide a minimal config for liveness check -app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker()) +app.Get(healthcheck.LivenessEndpoint, healthcheck.New()) + // Provide a minimal config for readiness check -app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker()) +app.Get(healthcheck.ReadinessEndpoint, healthcheck.New()) + // Provide a minimal config for startup check -app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker()) +app.Get(healthcheck.StartupEndpoint, healthcheck.New()) + // Provide a minimal config for check with custom endpoint -app.Get("/live", healthcheck.NewHealthChecker()) +app.Get("/live", healthcheck.New()) // Or extend your config for customization -app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{ +app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return true }, })) + // And it works the same for readiness, just change the route -app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{ +app.Get(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return true }, })) + // And it works the same for startup, just change the route -app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{ +app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return true }, })) + // With a custom route and custom probe -app.Get("/live", healthcheck.NewHealthChecker(healthcheck.Config{ +app.Get("/live", healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return true }, @@ -81,7 +87,7 @@ app.Get("/live", healthcheck.NewHealthChecker(healthcheck.Config{ // It can also be used with app.All, although it will only respond to requests with the GET method // in case of calling the route with any method which isn't GET, the return will be 404 Not Found when app.All is used // and 405 Method Not Allowed when app.Get is used -app.All(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{ +app.All(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return true }, @@ -108,7 +114,7 @@ type Config struct { // initialization and readiness checks // // Optional. Default: func(c fiber.Ctx) bool { return true } - Probe HealthChecker + Probe func(fiber.Ctx) bool } ``` @@ -117,7 +123,7 @@ type Config struct { The default configuration used by this middleware is defined as follows: ```go -func defaultProbe(fiber.Ctx) bool { return true } +func defaultProbe(_ fiber.Ctx) bool { return true } var ConfigDefault = Config{ Probe: defaultProbe, diff --git a/docs/whats_new.md b/docs/whats_new.md index bc569d3e..19f261ae 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -1564,25 +1564,25 @@ With the new version, each health check endpoint is configured separately, allow // after // Default liveness endpoint configuration -app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{ +app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return true }, })) // Default readiness endpoint configuration -app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker()) +app.Get(healthcheck.ReadinessEndpoint, healthcheck.New()) // New default startup endpoint configuration // Default endpoint is /startupz -app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{ +app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return serviceA.Ready() && serviceB.Ready() && ... }, })) // Custom liveness endpoint configuration -app.Get("/live", healthcheck.NewHealthChecker()) +app.Get("/live", healthcheck.New()) ``` #### Monitor diff --git a/middleware/healthcheck/config.go b/middleware/healthcheck/config.go index eba6d537..112f4039 100644 --- a/middleware/healthcheck/config.go +++ b/middleware/healthcheck/config.go @@ -18,18 +18,18 @@ type Config struct { // the application is in a state where it can handle requests (e.g., the server is up and running). // // Optional. Default: func(c fiber.Ctx) bool { return true } - Probe HealthChecker + Probe func(fiber.Ctx) bool } const ( - DefaultLivenessEndpoint = "/livez" - DefaultReadinessEndpoint = "/readyz" - DefaultStartupEndpoint = "/startupz" + LivenessEndpoint = "/livez" + ReadinessEndpoint = "/readyz" + StartupEndpoint = "/startupz" ) -func defaultProbe(fiber.Ctx) bool { return true } +func defaultProbe(_ fiber.Ctx) bool { return true } -func defaultConfigV3(config ...Config) Config { +func defaultConfig(config ...Config) Config { if len(config) < 1 { return Config{ Probe: defaultProbe, diff --git a/middleware/healthcheck/healthcheck.go b/middleware/healthcheck/healthcheck.go index 51a16d70..de2079ae 100644 --- a/middleware/healthcheck/healthcheck.go +++ b/middleware/healthcheck/healthcheck.go @@ -4,11 +4,8 @@ import ( "github.com/gofiber/fiber/v3" ) -// HealthChecker defines a function to check liveness or readiness of the application -type HealthChecker func(fiber.Ctx) bool - -func NewHealthChecker(config ...Config) fiber.Handler { - cfg := defaultConfigV3(config...) +func New(config ...Config) fiber.Handler { + cfg := defaultConfig(config...) return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true diff --git a/middleware/healthcheck/healthcheck_test.go b/middleware/healthcheck/healthcheck_test.go index 07efa3de..bccfddde 100644 --- a/middleware/healthcheck/healthcheck_test.go +++ b/middleware/healthcheck/healthcheck_test.go @@ -34,9 +34,9 @@ func Test_HealthCheck_Strict_Routing_Default(t *testing.T) { StrictRouting: true, }) - app.Get(DefaultLivenessEndpoint, NewHealthChecker()) - app.Get(DefaultReadinessEndpoint, NewHealthChecker()) - app.Get(DefaultStartupEndpoint, NewHealthChecker()) + app.Get(LivenessEndpoint, New()) + app.Get(ReadinessEndpoint, New()) + app.Get(StartupEndpoint, New()) shouldGiveOK(t, app, "/readyz") shouldGiveOK(t, app, "/livez") @@ -53,9 +53,9 @@ func Test_HealthCheck_Default(t *testing.T) { t.Parallel() app := fiber.New() - app.Get(DefaultLivenessEndpoint, NewHealthChecker()) - app.Get(DefaultReadinessEndpoint, NewHealthChecker()) - app.Get(DefaultStartupEndpoint, NewHealthChecker()) + app.Get(LivenessEndpoint, New()) + app.Get(ReadinessEndpoint, New()) + app.Get(StartupEndpoint, New()) shouldGiveOK(t, app, "/readyz") shouldGiveOK(t, app, "/livez") @@ -73,12 +73,12 @@ func Test_HealthCheck_Custom(t *testing.T) { app := fiber.New() c1 := make(chan struct{}, 1) - app.Get("/live", NewHealthChecker(Config{ + app.Get("/live", New(Config{ Probe: func(_ fiber.Ctx) bool { return true }, })) - app.Get("/ready", NewHealthChecker(Config{ + app.Get("/ready", New(Config{ Probe: func(_ fiber.Ctx) bool { select { case <-c1: @@ -88,7 +88,7 @@ func Test_HealthCheck_Custom(t *testing.T) { } }, })) - app.Get(DefaultStartupEndpoint, NewHealthChecker(Config{ + app.Get(StartupEndpoint, New(Config{ Probe: func(_ fiber.Ctx) bool { return false }, @@ -123,12 +123,12 @@ func Test_HealthCheck_Custom_Nested(t *testing.T) { app := fiber.New() c1 := make(chan struct{}, 1) - app.Get("/probe/live", NewHealthChecker(Config{ + app.Get("/probe/live", New(Config{ Probe: func(_ fiber.Ctx) bool { return true }, })) - app.Get("/probe/ready", NewHealthChecker(Config{ + app.Get("/probe/ready", New(Config{ Probe: func(_ fiber.Ctx) bool { select { case <-c1: @@ -164,15 +164,15 @@ func Test_HealthCheck_Next(t *testing.T) { app := fiber.New() - checker := NewHealthChecker(Config{ + checker := New(Config{ Next: func(_ fiber.Ctx) bool { return true }, }) - app.Get(DefaultLivenessEndpoint, checker) - app.Get(DefaultReadinessEndpoint, checker) - app.Get(DefaultStartupEndpoint, checker) + app.Get(LivenessEndpoint, checker) + app.Get(ReadinessEndpoint, checker) + app.Get(StartupEndpoint, checker) // This should give not found since there are no other handlers to execute // so it's like the route isn't defined at all @@ -184,9 +184,9 @@ func Test_HealthCheck_Next(t *testing.T) { func Benchmark_HealthCheck(b *testing.B) { app := fiber.New() - app.Get(DefaultLivenessEndpoint, NewHealthChecker()) - app.Get(DefaultReadinessEndpoint, NewHealthChecker()) - app.Get(DefaultStartupEndpoint, NewHealthChecker()) + app.Get(LivenessEndpoint, New()) + app.Get(ReadinessEndpoint, New()) + app.Get(StartupEndpoint, New()) h := app.Handler() fctx := &fasthttp.RequestCtx{} @@ -206,9 +206,9 @@ func Benchmark_HealthCheck(b *testing.B) { func Benchmark_HealthCheck_Parallel(b *testing.B) { app := fiber.New() - app.Get(DefaultLivenessEndpoint, NewHealthChecker()) - app.Get(DefaultReadinessEndpoint, NewHealthChecker()) - app.Get(DefaultStartupEndpoint, NewHealthChecker()) + app.Get(LivenessEndpoint, New()) + app.Get(ReadinessEndpoint, New()) + app.Get(StartupEndpoint, New()) h := app.Handler() From d19e993597d88b7df3e83a5937f2248c57475e16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Mon, 31 Mar 2025 10:31:59 +0300 Subject: [PATCH 2/7] :sparkles: feat: Add support for application state management (#3360) * :sparkles: feat: add support for application state management * increase test coverage * fix linter * Fix typo * add GetStateWithDefault helper * add docs * update what's new * add has method * fix linter * update * Add missing helpers for golang built-in types * Fix lint issues * Fix unit-tests. Update documentation * Fix docs, add missing benchmarks * Fix tests file * Update default example and test * Apply suggestions from code review --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Co-authored-by: Juan Calderon-Perez Co-authored-by: RW --- app.go | 10 + app_test.go | 10 + docs/api/constants.md | 2 +- docs/api/state.md | 640 +++++++++++++++++++++++++++ docs/whats_new.md | 1 + state.go | 322 ++++++++++++++ state_test.go | 981 ++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 1965 insertions(+), 1 deletion(-) create mode 100644 docs/api/state.md create mode 100644 state.go create mode 100644 state_test.go diff --git a/app.go b/app.go index 43ccaf9a..ae16a618 100644 --- a/app.go +++ b/app.go @@ -106,6 +106,8 @@ type App struct { tlsHandler *TLSHandler // Mount fields mountFields *mountFields + // state management + state *State // Route stack divided by HTTP methods stack [][]*Route // Route stack divided by HTTP methods and route prefixes @@ -527,6 +529,9 @@ func New(config ...Config) *App { // Define mountFields app.mountFields = newMountFields(app) + // Define state + app.state = newState() + // Override config if provided if len(config) > 0 { app.config = config[0] @@ -964,6 +969,11 @@ func (app *App) Hooks() *Hooks { return app.hooks } +// State returns the state struct to store global data in order to share it between handlers. +func (app *App) State() *State { + return app.state +} + var ErrTestGotEmptyResponse = errors.New("test: got empty response") // TestConfig is a struct holding Test settings diff --git a/app_test.go b/app_test.go index 84606e6b..b5d5ed46 100644 --- a/app_test.go +++ b/app_test.go @@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) { require.Equal(t, "/simple-route", sRoute2.Path) } +func Test_App_State(t *testing.T) { + t.Parallel() + app := New() + + app.State().Set("key", "value") + str, ok := app.State().GetString("key") + require.True(t, ok) + require.Equal(t, "value", str) +} + // go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4 func Benchmark_Communication_Flow(b *testing.B) { app := New() diff --git a/docs/api/constants.md b/docs/api/constants.md index 53bbb25c..70366de1 100644 --- a/docs/api/constants.md +++ b/docs/api/constants.md @@ -2,7 +2,7 @@ id: constants title: 📋 Constants description: Some constants for Fiber. -sidebar_position: 8 +sidebar_position: 9 --- ### HTTP methods were copied from net/http diff --git a/docs/api/state.md b/docs/api/state.md new file mode 100644 index 00000000..b22b9675 --- /dev/null +++ b/docs/api/state.md @@ -0,0 +1,640 @@ +# State Management + +This document details the state management functionality provided by Fiber, a thread-safe global key–value store used to store application dependencies and runtime data. The implementation is based on Go's `sync.Map`, ensuring concurrency safety. + +Below is the detailed description of all public methods and usage examples. + +## State Type + +`State` is a key–value store built on top of `sync.Map`. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as thread–safe access to runtime data. + +### Definition + +```go +// State is a key–value store for Fiber's app, used as a global storage for the app's dependencies. +// It is a thread–safe implementation of a map[string]any, using sync.Map. +type State struct { + dependencies sync.Map +} +``` + +## Methods on State + +### Set + +Set adds or updates a key–value pair in the State. + +```go +// Set adds or updates a key–value pair in the State. +func (s *State) Set(key string, value any) +``` + +**Usage Example:** + +```go +app.State().Set("appName", "My Fiber App") +``` + +### Get + +Get retrieves a value from the State. + +```go title="Signature" +func (s *State) Get(key string) (any, bool) +``` + +**Usage Example:** + +```go +value, ok := app.State().Get("appName") +if ok { + fmt.Println("App Name:", value) +} +``` + +### MustGet + +MustGet retrieves a value from the State and panics if the key is not found. + +```go title="Signature" +func (s *State) MustGet(key string) any +``` + +**Usage Example:** + +```go +appName := app.State().MustGet("appName") +fmt.Println("App Name:", appName) +``` + +### Has + +Has checks if a key exists in the State. + +```go title="Signature"s +func (s *State) Has(key string) bool +``` + +**Usage Example:** + +```go +if app.State().Has("appName") { + fmt.Println("App Name is set.") +} +``` + +### Delete + +Delete removes a key–value pair from the State. + +```go title="Signature" +func (s *State) Delete(key string) +``` + +**Usage Example:** + +```go +app.State().Delete("obsoleteKey") +``` + +### Reset + +Reset removes all keys from the State. + +```go title="Signature" +func (s *State) Reset() +``` + +**Usage Example:** + +```go +app.State().Reset() +``` + +### Keys + +Keys returns a slice containing all keys present in the State. + +```go title="Signature" +func (s *State) Keys() []string +``` + +**Usage Example:** + +```go +keys := app.State().Keys() +fmt.Println("State Keys:", keys) +``` + +### Len + +Len returns the number of keys in the State. + +```go +// Len returns the number of keys in the State. +func (s *State) Len() int +``` + +**Usage Example:** + +```go +fmt.Printf("Total State Entries: %d\n", app.State().Len()) +``` + +### GetString + +GetString retrieves a string value from the State. It returns the string and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetString(key string) (string, bool) +``` + +**Usage Example:** + +```go +if appName, ok := app.State().GetString("appName"); ok { + fmt.Println("App Name:", appName) +} +``` + +### GetInt + +GetInt retrieves an integer value from the State. It returns the int and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetInt(key string) (int, bool) +``` + +**Usage Example:** + +```go +if count, ok := app.State().GetInt("userCount"); ok { + fmt.Printf("User Count: %d\n", count) +} +``` + +### GetBool + +GetBool retrieves a boolean value from the State. It returns the bool and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetBool(key string) (value, bool) +``` + +**Usage Example:** + +```go +if debug, ok := app.State().GetBool("debugMode"); ok { + fmt.Printf("Debug Mode: %v\n", debug) +} +``` + +### GetFloat64 + +GetFloat64 retrieves a float64 value from the State. It returns the float64 and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetFloat64(key string) (float64, bool) +``` + +**Usage Example:** + +```go title="Signature" +if ratio, ok := app.State().GetFloat64("scalingFactor"); ok { + fmt.Printf("Scaling Factor: %f\n", ratio) +} +``` + +### GetUint + +GetUint retrieves a `uint` value from the State. + +```go title="Signature" +func (s *State) GetUint(key string) (uint, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint("maxConnections"); ok { + fmt.Printf("Max Connections: %d\n", val) +} +``` + +### GetInt8 + +GetInt8 retrieves an `int8` value from the State. + +```go title="Signature" +func (s *State) GetInt8(key string) (int8, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt8("threshold"); ok { + fmt.Printf("Threshold: %d\n", val) +} +``` + +### GetInt16 + +GetInt16 retrieves an `int16` value from the State. + +```go title="Signature" +func (s *State) GetInt16(key string) (int16, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt16("minValue"); ok { + fmt.Printf("Minimum Value: %d\n", val) +} +``` + +### GetInt32 + +GetInt32 retrieves an `int32` value from the State. + +```go title="Signature" +func (s *State) GetInt32(key string) (int32, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt32("portNumber"); ok { + fmt.Printf("Port Number: %d\n", val) +} +``` + +### GetInt64 + +GetInt64 retrieves an `int64` value from the State. + +```go title="Signature" +func (s *State) GetInt64(key string) (int64, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt64("fileSize"); ok { + fmt.Printf("File Size: %d\n", val) +} +``` + +### GetUint8 + +GetUint8 retrieves a `uint8` value from the State. + +```go title="Signature" +func (s *State) GetUint8(key string) (uint8, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint8("byteValue"); ok { + fmt.Printf("Byte Value: %d\n", val) +} +``` + +### GetUint16 + +GetUint16 retrieves a `uint16` value from the State. + +```go title="Signature" +func (s *State) GetUint16(key string) (uint16, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint16("limit"); ok { + fmt.Printf("Limit: %d\n", val) +} +``` + +### GetUint32 + +GetUint32 retrieves a `uint32` value from the State. + +```go title="Signature" +func (s *State) GetUint32(key string) (uint32, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint32("timeout"); ok { + fmt.Printf("Timeout: %d\n", val) +} +``` + +### GetUint64 + +GetUint64 retrieves a `uint64` value from the State. + +```go title="Signature" +func (s *State) GetUint64(key string) (uint64, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint64("maxSize"); ok { + fmt.Printf("Max Size: %d\n", val) +} +``` + +### GetUintptr + +GetUintptr retrieves a `uintptr` value from the State. + +```go title="Signature" +func (s *State) GetUintptr(key string) (uintptr, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUintptr("pointerValue"); ok { + fmt.Printf("Pointer Value: %d\n", val) +} +``` + +### GetFloat32 + +GetFloat32 retrieves a `float32` value from the State. + +```go title="Signature" +func (s *State) GetFloat32(key string) (float32, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetFloat32("scalingFactor32"); ok { + fmt.Printf("Scaling Factor (float32): %f\n", val) +} +``` + +### GetComplex64 + +GetComplex64 retrieves a `complex64` value from the State. + +```go title="Signature" +func (s *State) GetComplex64(key string) (complex64, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetComplex64("complexVal"); ok { + fmt.Printf("Complex Value (complex64): %v\n", val) +} +``` + +### GetComplex128 + +GetComplex128 retrieves a `complex128` value from the State. + +```go title="Signature" +func (s *State) GetComplex128(key string) (complex128, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetComplex128("complexVal128"); ok { + fmt.Printf("Complex Value (complex128): %v\n", val) +} +``` + +## Generic Functions + +Fiber provides generic functions to retrieve state values with type safety and fallback options. + +### GetState + +GetState retrieves a value from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful. + +```go title="Signature" +func GetState[T any](s *State, key string) (T, bool) +``` + +**Usage Example:** + +```go +// Retrieve an integer value safely. +userCount, ok := GetState[int](app.State(), "userCount") +if ok { + fmt.Printf("User Count: %d\n", userCount) +} +``` + +### MustGetState + +MustGetState retrieves a value from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails. + +```go title="Signature" +func MustGetState[T any](s *State, key string) T +``` + +**Usage Example:** + +```go +// Retrieve the value or panic if it is not present. +config := MustGetState[string](app.State(), "configFile") +fmt.Println("Config File:", config) +``` + +### GetStateWithDefault + +GetStateWithDefault retrieves a value from the State, casting it to the desired type. If the key is not present, it returns the provided default value. + +```go title="Signature" +func GetStateWithDefault[T any](s *State, key string, defaultVal T) T +``` + +**Usage Example:** + +```go +// Retrieve a value with a default fallback. +requestCount := GetStateWithDefault[int](app.State(), "requestCount", 0) +fmt.Printf("Request Count: %d\n", requestCount) +``` + +## Comprehensive Examples + +### Example: Request Counter + +This example demonstrates how to track the number of requests using the State. + +```go +package main + +import ( + "fmt" + + "github.com/gofiber/fiber/v3" +) + +func main() { + app := fiber.New() + + // Initialize state with a counter. + app.State().Set("requestCount", 0) + + // Middleware: Increase counter for every request. + app.Use(func(c fiber.Ctx) error { + count, _ := c.App().State().GetInt("requestCount") + app.State().Set("requestCount", count+1) + return c.Next() + }) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello World!") + }) + + app.Get("/stats", func(c fiber.Ctx) error { + count, _ := c.App().State().Get("requestCount") + return c.SendString(fmt.Sprintf("Total requests: %d", count)) + }) + + app.Listen(":3000") +} +``` + +### Example: Environment–Specific Configuration + +This example shows how to configure different settings based on the environment. + +```go +package main + +import ( + "os" + + "github.com/gofiber/fiber/v3" +) + +func main() { + app := fiber.New() + + // Determine environment. + environment := os.Getenv("ENV") + if environment == "" { + environment = "development" + } + app.State().Set("environment", environment) + + // Set environment-specific configurations. + if environment == "development" { + app.State().Set("apiUrl", "http://localhost:8080/api") + app.State().Set("debug", true) + } else { + app.State().Set("apiUrl", "https://api.production.com") + app.State().Set("debug", false) + } + + app.Get("/config", func(c fiber.Ctx) error { + config := map[string]any{ + "environment": environment, + "apiUrl": fiber.GetStateWithDefault(c.App().State(), "apiUrl", ""), + "debug": fiber.GetStateWithDefault(c.App().State(), "debug", false), + } + return c.JSON(config) + }) + + app.Listen(":3000") +} +``` + +### Example: Dependency Injection with State Management + +This example demonstrates how to use the State for dependency injection in a Fiber application. + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/gofiber/fiber/v3" + "github.com/redis/go-redis/v9" +) + +type User struct { + ID int `query:"id"` + Name string `query:"name"` + Email string `query:"email"` +} + +func main() { + app := fiber.New() + ctx := context.Background() + + // Initialize Redis client. + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", + DB: 0, + }) + + // Check the Redis connection. + if err := rdb.Ping(ctx).Err(); err != nil { + log.Fatalf("Could not connect to Redis: %v", err) + } + + // Inject the Redis client into Fiber's State for dependency injection. + app.State().Set("redis", rdb) + + app.Get("/user/create", func(c fiber.Ctx) error { + var user User + if err := c.Bind().Query(&user); err != nil { + return c.Status(fiber.StatusBadRequest).SendString(err.Error()) + } + + // Save the user to the database. + rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis") + if !ok { + return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found") + } + + // Save the user to the database. + key := fmt.Sprintf("user:%d", user.ID) + err := rdb.HSet(ctx, key, "name", user.Name, "email", user.Email).Err() + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.JSON(user) + }) + + app.Get("/user/:id", func(c fiber.Ctx) error { + id := c.Params("id") + + rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis") + if !ok { + return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found") + } + + key := fmt.Sprintf("user:%s", id) + user, err := rdb.HGetAll(ctx, key).Result() + if err == redis.Nil { + return c.Status(fiber.StatusNotFound).SendString("User not found") + } else if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.JSON(user) + }) + + app.Listen(":3000") +} +``` diff --git a/docs/whats_new.md b/docs/whats_new.md index 19f261ae..4bf28927 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -59,6 +59,7 @@ We have made several changes to the Fiber app, including: - **RegisterCustomBinder**: Allows for the registration of custom binders. - **RegisterCustomConstraint**: Allows for the registration of custom constraints. - **NewCtxFunc**: Introduces a new context function. +- **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details. ### Removed Methods diff --git a/state.go b/state.go new file mode 100644 index 00000000..43687603 --- /dev/null +++ b/state.go @@ -0,0 +1,322 @@ +package fiber + +import ( + "sync" +) + +// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies. +// It's a thread-safe implementation of a map[string]any, using sync.Map. +type State struct { + dependencies sync.Map +} + +// NewState creates a new instance of State. +func newState() *State { + return &State{ + dependencies: sync.Map{}, + } +} + +// Set sets a key-value pair in the State. +func (s *State) Set(key string, value any) { + s.dependencies.Store(key, value) +} + +// Get retrieves a value from the State. +func (s *State) Get(key string) (any, bool) { + return s.dependencies.Load(key) +} + +// MustGet retrieves a value from the State and panics if the key is not found. +func (s *State) MustGet(key string) any { + if dep, ok := s.Get(key); ok { + return dep + } + + panic("state: dependency not found!") +} + +// Has checks if a key is present in the State. +// It returns a boolean indicating if the key is present. +func (s *State) Has(key string) bool { + _, ok := s.Get(key) + return ok +} + +// Delete removes a key-value pair from the State. +func (s *State) Delete(key string) { + s.dependencies.Delete(key) +} + +// Reset resets the State by removing all keys. +func (s *State) Reset() { + s.dependencies.Clear() +} + +// Keys returns a slice containing all keys present in the State. +func (s *State) Keys() []string { + keys := make([]string, 0) + s.dependencies.Range(func(key, _ any) bool { + keyStr, ok := key.(string) + if !ok { + return false + } + + keys = append(keys, keyStr) + return true + }) + + return keys +} + +// Len returns the number of keys in the State. +func (s *State) Len() int { + length := 0 + s.dependencies.Range(func(_, _ any) bool { + length++ + return true + }) + + return length +} + +// GetState retrieves a value from the State and casts it to the desired type. +// It returns the casted value and a boolean indicating if the cast was successful. +func GetState[T any](s *State, key string) (T, bool) { + dep, ok := s.Get(key) + + if ok { + depT, okCast := dep.(T) + return depT, okCast + } + + var zeroVal T + return zeroVal, false +} + +// MustGetState retrieves a value from the State and casts it to the desired type. +// It panics if the key is not found or if the type assertion fails. +func MustGetState[T any](s *State, key string) T { + dep, ok := GetState[T](s, key) + if !ok { + panic("state: dependency not found!") + } + + return dep +} + +// GetStateWithDefault retrieves a value from the State, +// casting it to the desired type. If the key is not present, +// it returns the provided default value. +func GetStateWithDefault[T any](s *State, key string, defaultVal T) T { + dep, ok := GetState[T](s, key) + if !ok { + return defaultVal + } + + return dep +} + +// GetString retrieves a string value from the State. +// It returns the string and a boolean indicating successful type assertion. +func (s *State) GetString(key string) (string, bool) { + dep, ok := s.Get(key) + if ok { + depString, okCast := dep.(string) + return depString, okCast + } + + return "", false +} + +// GetInt retrieves an integer value from the State. +// It returns the int and a boolean indicating successful type assertion. +func (s *State) GetInt(key string) (int, bool) { + dep, ok := s.Get(key) + if ok { + depInt, okCast := dep.(int) + return depInt, okCast + } + + return 0, false +} + +// GetBool retrieves a boolean value from the State. +// It returns the bool and a boolean indicating successful type assertion. +func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here + dep, ok := s.Get(key) + if ok { + depBool, okCast := dep.(bool) + return depBool, okCast + } + + return false, false +} + +// GetFloat64 retrieves a float64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetFloat64(key string) (float64, bool) { + dep, ok := s.Get(key) + if ok { + depFloat64, okCast := dep.(float64) + return depFloat64, okCast + } + + return 0, false +} + +// GetUint retrieves a uint value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint(key string) (uint, bool) { + dep, ok := s.Get(key) + if ok { + if depUint, okCast := dep.(uint); okCast { + return depUint, true + } + } + return 0, false +} + +// GetInt8 retrieves an int8 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt8(key string) (int8, bool) { + dep, ok := s.Get(key) + if ok { + if depInt8, okCast := dep.(int8); okCast { + return depInt8, true + } + } + return 0, false +} + +// GetInt16 retrieves an int16 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt16(key string) (int16, bool) { + dep, ok := s.Get(key) + if ok { + if depInt16, okCast := dep.(int16); okCast { + return depInt16, true + } + } + return 0, false +} + +// GetInt32 retrieves an int32 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt32(key string) (int32, bool) { + dep, ok := s.Get(key) + if ok { + if depInt32, okCast := dep.(int32); okCast { + return depInt32, true + } + } + return 0, false +} + +// GetInt64 retrieves an int64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt64(key string) (int64, bool) { + dep, ok := s.Get(key) + if ok { + if depInt64, okCast := dep.(int64); okCast { + return depInt64, true + } + } + return 0, false +} + +// GetUint8 retrieves a uint8 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint8(key string) (uint8, bool) { + dep, ok := s.Get(key) + if ok { + if depUint8, okCast := dep.(uint8); okCast { + return depUint8, true + } + } + return 0, false +} + +// GetUint16 retrieves a uint16 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint16(key string) (uint16, bool) { + dep, ok := s.Get(key) + if ok { + if depUint16, okCast := dep.(uint16); okCast { + return depUint16, true + } + } + return 0, false +} + +// GetUint32 retrieves a uint32 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint32(key string) (uint32, bool) { + dep, ok := s.Get(key) + if ok { + if depUint32, okCast := dep.(uint32); okCast { + return depUint32, true + } + } + return 0, false +} + +// GetUint64 retrieves a uint64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint64(key string) (uint64, bool) { + dep, ok := s.Get(key) + if ok { + if depUint64, okCast := dep.(uint64); okCast { + return depUint64, true + } + } + return 0, false +} + +// GetUintptr retrieves a uintptr value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUintptr(key string) (uintptr, bool) { + dep, ok := s.Get(key) + if ok { + if depUintptr, okCast := dep.(uintptr); okCast { + return depUintptr, true + } + } + return 0, false +} + +// GetFloat32 retrieves a float32 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetFloat32(key string) (float32, bool) { + dep, ok := s.Get(key) + if ok { + if depFloat32, okCast := dep.(float32); okCast { + return depFloat32, true + } + } + return 0, false +} + +// GetComplex64 retrieves a complex64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetComplex64(key string) (complex64, bool) { + dep, ok := s.Get(key) + if ok { + if depComplex64, okCast := dep.(complex64); okCast { + return depComplex64, true + } + } + return 0, false +} + +// GetComplex128 retrieves a complex128 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetComplex128(key string) (complex128, bool) { + dep, ok := s.Get(key) + if ok { + if depComplex128, okCast := dep.(complex128); okCast { + return depComplex128, true + } + } + return 0, false +} diff --git a/state_test.go b/state_test.go new file mode 100644 index 00000000..e96ea0dd --- /dev/null +++ b/state_test.go @@ -0,0 +1,981 @@ +package fiber + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestState_SetAndGet_WithApp(t *testing.T) { + t.Parallel() + // Create app + app := New() + + // test setting and getting a value + app.State().Set("foo", "bar") + val, ok := app.State().Get("foo") + require.True(t, ok) + require.Equal(t, "bar", val) + + // test key not found + _, ok = app.State().Get("unknown") + require.False(t, ok) +} + +func TestState_SetAndGet(t *testing.T) { + t.Parallel() + st := newState() + + // test setting and getting a value + st.Set("foo", "bar") + val, ok := st.Get("foo") + require.True(t, ok) + require.Equal(t, "bar", val) + + // test key not found + _, ok = st.Get("unknown") + require.False(t, ok) +} + +func TestState_GetString(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("str", "hello") + s, ok := st.GetString("str") + require.True(t, ok) + require.Equal(t, "hello", s) + + // wrong type should return false + st.Set("num", 123) + s, ok = st.GetString("num") + require.False(t, ok) + require.Equal(t, "", s) + + // missing key should return false + s, ok = st.GetString("missing") + require.False(t, ok) + require.Equal(t, "", s) +} + +func TestState_GetInt(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("num", 456) + i, ok := st.GetInt("num") + require.True(t, ok) + require.Equal(t, 456, i) + + // wrong type should return zero value + st.Set("str", "abc") + i, ok = st.GetInt("str") + require.False(t, ok) + require.Equal(t, 0, i) + + // missing key should return zero value + i, ok = st.GetInt("missing") + require.False(t, ok) + require.Equal(t, 0, i) +} + +func TestState_GetBool(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + b, ok := st.GetBool("flag") + require.True(t, ok) + require.True(t, b) + + // wrong type + st.Set("num", 1) + b, ok = st.GetBool("num") + require.False(t, ok) + require.False(t, b) + + // missing key should return false + b, ok = st.GetBool("missing") + require.False(t, ok) + require.False(t, b) +} + +func TestState_GetFloat64(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("pi", 3.14) + f, ok := st.GetFloat64("pi") + require.True(t, ok) + require.InDelta(t, 3.14, f, 0.0001) + + // wrong type should return zero value + st.Set("int", 10) + f, ok = st.GetFloat64("int") + require.False(t, ok) + require.InDelta(t, 0.0, f, 0.0001) + + // missing key should return zero value + f, ok = st.GetFloat64("missing") + require.False(t, ok) + require.InDelta(t, 0.0, f, 0.0001) +} + +func TestState_GetUint(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint", uint(100)) + u, ok := st.GetUint("uint") + require.True(t, ok) + require.Equal(t, uint(100), u) + + st.Set("wrong", "not uint") + u, ok = st.GetUint("wrong") + require.False(t, ok) + require.Equal(t, uint(0), u) + + u, ok = st.GetUint("missing") + require.False(t, ok) + require.Equal(t, uint(0), u) +} + +func TestState_GetInt8(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int8", int8(10)) + i, ok := st.GetInt8("int8") + require.True(t, ok) + require.Equal(t, int8(10), i) + + st.Set("wrong", "not int8") + i, ok = st.GetInt8("wrong") + require.False(t, ok) + require.Equal(t, int8(0), i) + + i, ok = st.GetInt8("missing") + require.False(t, ok) + require.Equal(t, int8(0), i) +} + +func TestState_GetInt16(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int16", int16(200)) + i, ok := st.GetInt16("int16") + require.True(t, ok) + require.Equal(t, int16(200), i) + + st.Set("wrong", "not int16") + i, ok = st.GetInt16("wrong") + require.False(t, ok) + require.Equal(t, int16(0), i) + + i, ok = st.GetInt16("missing") + require.False(t, ok) + require.Equal(t, int16(0), i) +} + +func TestState_GetInt32(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int32", int32(3000)) + i, ok := st.GetInt32("int32") + require.True(t, ok) + require.Equal(t, int32(3000), i) + + st.Set("wrong", "not int32") + i, ok = st.GetInt32("wrong") + require.False(t, ok) + require.Equal(t, int32(0), i) + + i, ok = st.GetInt32("missing") + require.False(t, ok) + require.Equal(t, int32(0), i) +} + +func TestState_GetInt64(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int64", int64(4000)) + i, ok := st.GetInt64("int64") + require.True(t, ok) + require.Equal(t, int64(4000), i) + + st.Set("wrong", "not int64") + i, ok = st.GetInt64("wrong") + require.False(t, ok) + require.Equal(t, int64(0), i) + + i, ok = st.GetInt64("missing") + require.False(t, ok) + require.Equal(t, int64(0), i) +} + +func TestState_GetUint8(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint8", uint8(20)) + u, ok := st.GetUint8("uint8") + require.True(t, ok) + require.Equal(t, uint8(20), u) + + st.Set("wrong", "not uint8") + u, ok = st.GetUint8("wrong") + require.False(t, ok) + require.Equal(t, uint8(0), u) + + u, ok = st.GetUint8("missing") + require.False(t, ok) + require.Equal(t, uint8(0), u) +} + +func TestState_GetUint16(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint16", uint16(300)) + u, ok := st.GetUint16("uint16") + require.True(t, ok) + require.Equal(t, uint16(300), u) + + st.Set("wrong", "not uint16") + u, ok = st.GetUint16("wrong") + require.False(t, ok) + require.Equal(t, uint16(0), u) + + u, ok = st.GetUint16("missing") + require.False(t, ok) + require.Equal(t, uint16(0), u) +} + +func TestState_GetUint32(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint32", uint32(400000)) + u, ok := st.GetUint32("uint32") + require.True(t, ok) + require.Equal(t, uint32(400000), u) + + st.Set("wrong", "not uint32") + u, ok = st.GetUint32("wrong") + require.False(t, ok) + require.Equal(t, uint32(0), u) + + u, ok = st.GetUint32("missing") + require.False(t, ok) + require.Equal(t, uint32(0), u) +} + +func TestState_GetUint64(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint64", uint64(5000000)) + u, ok := st.GetUint64("uint64") + require.True(t, ok) + require.Equal(t, uint64(5000000), u) + + st.Set("wrong", "not uint64") + u, ok = st.GetUint64("wrong") + require.False(t, ok) + require.Equal(t, uint64(0), u) + + u, ok = st.GetUint64("missing") + require.False(t, ok) + require.Equal(t, uint64(0), u) +} + +func TestState_GetUintptr(t *testing.T) { + t.Parallel() + st := newState() + + var ptr uintptr = 12345 + st.Set("uintptr", ptr) + u, ok := st.GetUintptr("uintptr") + require.True(t, ok) + require.Equal(t, ptr, u) + + st.Set("wrong", "not uintptr") + u, ok = st.GetUintptr("wrong") + require.False(t, ok) + require.Equal(t, uintptr(0), u) + + u, ok = st.GetUintptr("missing") + require.False(t, ok) + require.Equal(t, uintptr(0), u) +} + +func TestState_GetFloat32(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("float32", float32(3.14)) + f, ok := st.GetFloat32("float32") + require.True(t, ok) + require.InDelta(t, float32(3.14), f, 0.0001) + + st.Set("wrong", "not float32") + f, ok = st.GetFloat32("wrong") + require.False(t, ok) + require.InDelta(t, float32(0), f, 0.0001) + + f, ok = st.GetFloat32("missing") + require.False(t, ok) + require.InDelta(t, float32(0), f, 0.0001) +} + +func TestState_GetComplex64(t *testing.T) { + t.Parallel() + st := newState() + + var c complex64 = complex(2, 3) + st.Set("complex64", c) + cRes, ok := st.GetComplex64("complex64") + require.True(t, ok) + require.Equal(t, c, cRes) + + st.Set("wrong", "not complex64") + cRes, ok = st.GetComplex64("wrong") + require.False(t, ok) + require.Equal(t, complex64(0), cRes) + + cRes, ok = st.GetComplex64("missing") + require.False(t, ok) + require.Equal(t, complex64(0), cRes) +} + +func TestState_GetComplex128(t *testing.T) { + t.Parallel() + st := newState() + + c := complex(4, 5) + st.Set("complex128", c) + cRes, ok := st.GetComplex128("complex128") + require.True(t, ok) + require.Equal(t, c, cRes) + + st.Set("wrong", "not complex128") + cRes, ok = st.GetComplex128("wrong") + require.False(t, ok) + require.Equal(t, complex128(0), cRes) + + cRes, ok = st.GetComplex128("missing") + require.False(t, ok) + require.Equal(t, complex128(0), cRes) +} + +func TestState_MustGet(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("exists", "value") + val := st.MustGet("exists") + require.Equal(t, "value", val) + + // must-get on missing key should panic + require.Panics(t, func() { + _ = st.MustGet("missing") + }) +} + +func TestState_Has(t *testing.T) { + t.Parallel() + + st := newState() + + st.Set("key", "value") + require.True(t, st.Has("key")) +} + +func TestState_Delete(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("key", "value") + st.Delete("key") + _, ok := st.Get("key") + require.False(t, ok) +} + +func TestState_Reset(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("a", 1) + st.Set("b", 2) + st.Reset() + require.Equal(t, 0, st.Len()) + require.Empty(t, st.Keys()) +} + +func TestState_Keys(t *testing.T) { + t.Parallel() + st := newState() + + keys := []string{"one", "two", "three"} + for _, k := range keys { + st.Set(k, k) + } + + returnedKeys := st.Keys() + require.ElementsMatch(t, keys, returnedKeys) +} + +func TestState_Len(t *testing.T) { + t.Parallel() + st := newState() + + require.Equal(t, 0, st.Len()) + + st.Set("a", "a") + require.Equal(t, 1, st.Len()) + + st.Set("b", "b") + require.Equal(t, 2, st.Len()) + + st.Delete("a") + require.Equal(t, 1, st.Len()) +} + +type testCase[T any] struct { //nolint:govet // It does not really matter for test + name string + key string + value any + expected T + ok bool +} + +func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) { + t.Helper() + + st := newState() + for _, tc := range tests { + st.Set(tc.key, tc.value) + got, ok := getter(st, tc.key) + require.Equal(t, tc.ok, ok, tc.name) + require.Equal(t, tc.expected, got, tc.name) + } +} + +func TestState_GetGeneric(t *testing.T) { + t.Parallel() + + runGenericTest[int](t, GetState[int], []testCase[int]{ + {"int correct conversion", "num", 42, 42, true}, + {"int wrong conversion from string", "str", "abc", 0, false}, + }) + + runGenericTest[string](t, GetState[string], []testCase[string]{ + {"string correct conversion", "strVal", "hello", "hello", true}, + {"string wrong conversion from int", "intVal", 100, "", false}, + }) + + runGenericTest[bool](t, GetState[bool], []testCase[bool]{ + {"bool correct conversion", "flag", true, true, true}, + {"bool wrong conversion from int", "intFlag", 1, false, false}, + }) + + runGenericTest[float64](t, GetState[float64], []testCase[float64]{ + {"float64 correct conversion", "pi", 3.14, 3.14, true}, + {"float64 wrong conversion from int", "intVal", 10, 0.0, false}, + }) +} + +func Test_MustGetStateGeneric(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + flag := MustGetState[bool](st, "flag") + require.True(t, flag) + + // mismatched type should panic + require.Panics(t, func() { + _ = MustGetState[string](st, "flag") + }) + + // missing key should also panic + require.Panics(t, func() { + _ = MustGetState[string](st, "missing") + }) +} + +func Test_GetStateWithDefault(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + flag := GetStateWithDefault(st, "flag", false) + require.True(t, flag) + + // mismatched type should return the default value + str := GetStateWithDefault(st, "flag", "default") + require.Equal(t, "default", str) + + // missing key should return the default value + flag = GetStateWithDefault(st, "missing", false) + require.False(t, flag) +} + +func BenchmarkState_Set(b *testing.B) { + b.ReportAllocs() + + st := newState() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } +} + +func BenchmarkState_Get(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.Get(key) + } +} + +func BenchmarkState_GetString(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, strconv.Itoa(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetString(key) + } +} + +func BenchmarkState_GetInt(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt(key) + } +} + +func BenchmarkState_GetBool(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i%2 == 0) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetBool(key) + } +} + +func BenchmarkState_GetFloat64(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetFloat64(key) + } +} + +func BenchmarkState_MustGet(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.MustGet(key) + } +} + +func BenchmarkState_GetStateGeneric(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + GetState[int](st, key) + } +} + +func BenchmarkState_MustGetStateGeneric(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + MustGetState[int](st, key) + } +} + +func BenchmarkState_GetStateWithDefault(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + GetStateWithDefault[int](st, key, 0) + } +} + +func BenchmarkState_Has(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st.Has("key" + strconv.Itoa(i%n)) + } +} + +func BenchmarkState_Delete(b *testing.B) { + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st := newState() + st.Set("a", 1) + st.Delete("a") + } +} + +func BenchmarkState_Reset(b *testing.B) { + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st := newState() + // add a fixed number of keys before clearing + for j := 0; j < 100; j++ { + st.Set("key"+strconv.Itoa(j), j) + } + st.Reset() + } +} + +func BenchmarkState_Keys(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = st.Keys() + } +} + +func BenchmarkState_Len(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = st.Len() + } +} + +func BenchmarkState_GetUint(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint(key) + } +} + +func BenchmarkState_GetInt8(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int8 values (using modulo to stay in range). + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int8(i%128)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt8(key) + } +} + +func BenchmarkState_GetInt16(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int16 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int16(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt16(key) + } +} + +func BenchmarkState_GetInt32(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int32 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int32(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt32(key) + } +} + +func BenchmarkState_GetInt64(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int64 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int64(i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt64(key) + } +} + +func BenchmarkState_GetUint8(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint8 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint8(i%256)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint8(key) + } +} + +func BenchmarkState_GetUint16(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint16 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint16(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint16(key) + } +} + +func BenchmarkState_GetUint32(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint32 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint32(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint32(key) + } +} + +func BenchmarkState_GetUint64(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint64 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint64(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint64(key) + } +} + +func BenchmarkState_GetUintptr(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uintptr values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uintptr(i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUintptr(key) + } +} + +func BenchmarkState_GetFloat32(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with float32 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, float32(i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetFloat32(key) + } +} + +func BenchmarkState_GetComplex64(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with complex64 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + // Create a complex64 value with both real and imaginary parts. + st.Set(key, complex(float32(i), float32(i))) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetComplex64(key) + } +} + +func BenchmarkState_GetComplex128(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with complex128 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + // Create a complex128 value with both real and imaginary parts. + st.Set(key, complex(float64(i), float64(i))) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetComplex128(key) + } +} From c2e39b75703219f6d1fcec11ab4b6f8aaa87c2f0 Mon Sep 17 00:00:00 2001 From: Kashiwa <13825170+ksw2000@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:49:40 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor:=20Add=20find?= =?UTF-8?q?NextNonEscapedCharPosition=20for=20single-byte=20charset=20case?= =?UTF-8?q?s=20(#3378)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ♻️ Refactor: add findNextNonEscapedCharsetPosition to process a single-byte parameter ``` goos: linux goarch: amd64 pkg: github.com/gofiber/fiber/v3 cpu: AMD EPYC 9J14 96-Core Processor │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ _RoutePatternMatch//api/v1/const_|_match_|_/api/v1/const-16 160.4n ± 1% 159.0n ± 0% -0.84% (p=0.000 n=20) _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1-16 151.6n ± 0% 150.8n ± 0% -0.53% (p=0.005 n=20) _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1/-16 151.7n ± 0% 150.6n ± 0% -0.73% (p=0.000 n=20) _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1/something-16 162.3n ± 0% 160.8n ± 0% -0.96% (p=0.000 n=20) _RoutePatternMatch//api/:param/fixedEnd_|_match_|_/api/abc/fixedEnd-16 452.9n ± 1% 435.8n ± 0% -3.79% (p=0.000 n=20) _RoutePatternMatch//api/:param/fixedEnd_|_not_match_|_/api/abc/def/fixedEnd-16 455.6n ± 1% 435.7n ± 0% -4.38% (p=0.000 n=20) _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity-16 524.4n ± 1% 507.6n ± 1% -3.19% (p=0.000 n=20) _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity/-16 528.2n ± 0% 508.7n ± 0% -3.69% (p=0.000 n=20) _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity/1-16 528.1n ± 0% 510.6n ± 0% -3.31% (p=0.000 n=20) _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v-16 500.3n ± 0% 489.0n ± 0% -2.27% (p=0.000 n=20) _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v2-16 502.1n ± 0% 489.9n ± 0% -2.44% (p=0.000 n=20) _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v1/-16 515.5n ± 0% 498.8n ± 0% -3.24% (p=0.000 n=20) geomean 339.4n 331.1n -2.46% │ old.txt │ new.txt │ │ B/op │ B/op vs base │ _RoutePatternMatch//api/v1/const_|_match_|_/api/v1/const-16 144.0 ± 0% 144.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1-16 136.0 ± 0% 136.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1/-16 136.0 ± 0% 136.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1/something-16 152.0 ± 0% 152.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/:param/fixedEnd_|_match_|_/api/abc/fixedEnd-16 368.0 ± 0% 368.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/:param/fixedEnd_|_not_match_|_/api/abc/def/fixedEnd-16 368.0 ± 0% 368.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity-16 432.0 ± 0% 432.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity/-16 432.0 ± 0% 432.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity/1-16 432.0 ± 0% 432.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v-16 424.0 ± 0% 424.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v2-16 424.0 ± 0% 424.0 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v1/-16 424.0 ± 0% 424.0 ± 0% ~ (p=1.000 n=20) ¹ geomean 288.8 288.8 +0.00% ¹ all samples are equal │ old.txt │ new.txt │ │ allocs/op │ allocs/op vs base │ _RoutePatternMatch//api/v1/const_|_match_|_/api/v1/const-16 4.000 ± 0% 4.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1-16 4.000 ± 0% 4.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1/-16 4.000 ± 0% 4.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/const_|_not_match_|_/api/v1/something-16 4.000 ± 0% 4.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/:param/fixedEnd_|_match_|_/api/abc/fixedEnd-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/:param/fixedEnd_|_not_match_|_/api/abc/def/fixedEnd-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity/-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_match_|_/api/v1/entity/1-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v2-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ _RoutePatternMatch//api/v1/:param/*_|_not_match_|_/api/v1/-16 9.000 ± 0% 9.000 ± 0% ~ (p=1.000 n=20) ¹ geomean 6.868 6.868 +0.00% ¹ all samples are equal ``` --- path.go | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/path.go b/path.go index b188a41c..8cfde73f 100644 --- a/path.go +++ b/path.go @@ -123,16 +123,6 @@ var ( parameterDelimiterChars = append([]byte{paramStarterChar, escapeChar}, routeDelimiter...) // list of chars to find the end of a parameter parameterEndChars = append([]byte{optionalParam}, parameterDelimiterChars...) - // list of parameter constraint start - parameterConstraintStartChars = []byte{paramConstraintStart} - // list of parameter constraint end - parameterConstraintEndChars = []byte{paramConstraintEnd} - // list of parameter separator - parameterConstraintSeparatorChars = []byte{paramConstraintSeparator} - // list of parameter constraint data start - parameterConstraintDataStartChars = []byte{paramConstraintDataStart} - // list of parameter constraint data separator - parameterConstraintDataSeparatorChars = []byte{paramConstraintDataSeparator} ) // RoutePatternMatch checks if a given path matches a Fiber route pattern. @@ -337,8 +327,8 @@ func (parser *routeParser) analyseParameterPart(pattern string, customConstraint // find constraint part if exists in the parameter part and remove it if parameterEndPosition > 0 { - parameterConstraintStart = findNextNonEscapedCharsetPosition(pattern[0:parameterEndPosition], parameterConstraintStartChars) - parameterConstraintEnd = strings.LastIndexByte(pattern[0:parameterEndPosition+1], paramConstraintEnd) + parameterConstraintStart = findNextNonEscapedCharPosition(pattern[:parameterEndPosition], paramConstraintStart) + parameterConstraintEnd = strings.LastIndexByte(pattern[:parameterEndPosition+1], paramConstraintEnd) } // cut params part @@ -351,11 +341,11 @@ func (parser *routeParser) analyseParameterPart(pattern string, customConstraint if hasConstraint := parameterConstraintStart != -1 && parameterConstraintEnd != -1; hasConstraint { constraintString := pattern[parameterConstraintStart+1 : parameterConstraintEnd] - userConstraints := splitNonEscaped(constraintString, string(parameterConstraintSeparatorChars)) + userConstraints := splitNonEscaped(constraintString, paramConstraintSeparator) constraints = make([]*Constraint, 0, len(userConstraints)) for _, c := range userConstraints { - start := findNextNonEscapedCharsetPosition(c, parameterConstraintDataStartChars) + start := findNextNonEscapedCharPosition(c, paramConstraintDataStart) end := strings.LastIndexByte(c, paramConstraintDataEnd) // Assign constraint @@ -368,7 +358,7 @@ func (parser *routeParser) analyseParameterPart(pattern string, customConstraint // remove escapes from data if constraint.ID != regexConstraint { - constraint.Data = splitNonEscaped(c[start+1:end], string(parameterConstraintDataSeparatorChars)) + constraint.Data = splitNonEscaped(c[start+1:end], paramConstraintDataSeparator) if len(constraint.Data) == 1 { constraint.Data[0] = RemoveEscapeChar(constraint.Data[0]) } else if len(constraint.Data) == 2 { // This is fine, we simply expect two parts @@ -432,11 +422,11 @@ func findNextCharsetPosition(search string, charset []byte) int { return nextPosition } -// findNextCharsetPositionConstraint search the next char position from the charset +// findNextCharsetPositionConstraint searches the next char position from the charset // unlike findNextCharsetPosition, it takes care of constraint start-end chars to parse route pattern func findNextCharsetPositionConstraint(search string, charset []byte) int { - constraintStart := findNextNonEscapedCharsetPosition(search, parameterConstraintStartChars) - constraintEnd := findNextNonEscapedCharsetPosition(search, parameterConstraintEndChars) + constraintStart := findNextNonEscapedCharPosition(search, paramConstraintStart) + constraintEnd := findNextNonEscapedCharPosition(search, paramConstraintEnd) nextPosition := -1 for _, char := range charset { @@ -452,7 +442,7 @@ func findNextCharsetPositionConstraint(search string, charset []byte) int { return nextPosition } -// findNextNonEscapedCharsetPosition search the next char position from the charset and skip the escaped characters +// findNextNonEscapedCharsetPosition searches the next char position from the charset and skips the escaped characters func findNextNonEscapedCharsetPosition(search string, charset []byte) int { pos := findNextCharsetPosition(search, charset) for pos > 0 && search[pos-1] == escapeChar { @@ -471,16 +461,26 @@ func findNextNonEscapedCharsetPosition(search string, charset []byte) int { return pos } +// findNextNonEscapedCharPosition searches the next char position and skips the escaped characters +func findNextNonEscapedCharPosition(search string, char byte) int { + for i := 0; i < len(search); i++ { + if search[i] == char && (i == 0 || search[i-1] != escapeChar) { + return i + } + } + return -1 +} + // splitNonEscaped slices s into all substrings separated by sep and returns a slice of the substrings between those separators // This function also takes a care of escape char when splitting. -func splitNonEscaped(s, sep string) []string { +func splitNonEscaped(s string, sep byte) []string { var result []string - i := findNextNonEscapedCharsetPosition(s, []byte(sep)) + i := findNextNonEscapedCharPosition(s, sep) for i > -1 { result = append(result, s[:i]) - s = s[i+len(sep):] - i = findNextNonEscapedCharsetPosition(s, []byte(sep)) + s = s[i+1:] + i = findNextNonEscapedCharPosition(s, sep) } return append(result, s) From d8f9548650f8629f5f86c54bffd8f9e29ef14b3b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 12:07:59 +0000 Subject: [PATCH 4/7] build(deps): bump github.com/fxamacker/cbor/v2 from 2.7.0 to 2.8.0 Bumps [github.com/fxamacker/cbor/v2](https://github.com/fxamacker/cbor) from 2.7.0 to 2.8.0. - [Release notes](https://github.com/fxamacker/cbor/releases) - [Commits](https://github.com/fxamacker/cbor/compare/v2.7.0...v2.8.0) --- updated-dependencies: - dependency-name: github.com/fxamacker/cbor/v2 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- go.mod | 4 +++- go.sum | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index f4cd17c1..dea86813 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/gofiber/fiber/v3 go 1.23.0 +toolchain go1.24.1 + require ( github.com/gofiber/schema v1.3.0 github.com/gofiber/utils/v2 v2.0.0-beta.7 @@ -18,7 +20,7 @@ require ( require ( github.com/andybalholm/brotli v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/fxamacker/cbor/v2 v2.7.0 // direct + github.com/fxamacker/cbor/v2 v2.8.0 // direct github.com/klauspost/compress v1.17.11 // indirect github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 71824581..81b9a920 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7X github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= -github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU= +github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gofiber/schema v1.3.0 h1:K3F3wYzAY+aivfCCEHPufCthu5/13r/lzp1nuk6mr3Q= github.com/gofiber/schema v1.3.0/go.mod h1:YYwj01w3hVfaNjhtJzaqetymL56VW642YS3qZPhuE6c= github.com/gofiber/utils/v2 v2.0.0-beta.7 h1:NnHFrRHvhrufPABdWajcKZejz9HnCWmT/asoxRsiEbQ= @@ -34,8 +34,6 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= -golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= -golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 2f794d9f88898c998e6bb611c9ce6f4c15223bc8 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Mon, 31 Mar 2025 08:14:18 -0400 Subject: [PATCH 5/7] Update go.mod --- go.mod | 1 - 1 file changed, 1 deletion(-) diff --git a/go.mod b/go.mod index dea86813..39f1a36d 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,6 @@ module github.com/gofiber/fiber/v3 go 1.23.0 -toolchain go1.24.1 require ( github.com/gofiber/schema v1.3.0 From dec28010e9cd6ba93c2e4439695229720ce0ea1e Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Mon, 31 Mar 2025 08:14:32 -0400 Subject: [PATCH 6/7] Update go.mod --- go.mod | 1 - 1 file changed, 1 deletion(-) diff --git a/go.mod b/go.mod index 39f1a36d..f5490bf1 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,6 @@ module github.com/gofiber/fiber/v3 go 1.23.0 - require ( github.com/gofiber/schema v1.3.0 github.com/gofiber/utils/v2 v2.0.0-beta.7 From bb12633c8ba8f085e9a568de49e505476d51b2d0 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 31 Mar 2025 11:55:01 -0300 Subject: [PATCH 7/7] =?UTF-8?q?Revert=20"=F0=9F=94=A5=20feat:=20Add=20supp?= =?UTF-8?q?ort=20for=20context.Context=20in=20keyauth=20middleware"=20(#33?= =?UTF-8?q?64)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Revert "🔥 feat: Add support for context.Context in keyauth middleware (#3287)" This reverts commit 4177ab4086a97648553f34bcff2ff81a137d31f3. Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- middleware/keyauth/keyauth.go | 22 ++------ middleware/keyauth/keyauth_test.go | 82 +++++++++--------------------- 2 files changed, 29 insertions(+), 75 deletions(-) diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go index 54ecdbe5..e245ba42 100644 --- a/middleware/keyauth/keyauth.go +++ b/middleware/keyauth/keyauth.go @@ -2,7 +2,6 @@ package keyauth import ( - "context" "errors" "fmt" "net/url" @@ -60,10 +59,7 @@ func New(config ...Config) fiber.Handler { valid, err := cfg.Validator(c, key) if err == nil && valid { - // Store in both Locals and Context c.Locals(tokenKey, key) - ctx := context.WithValue(c.Context(), tokenKey, key) - c.SetContext(ctx) return cfg.SuccessHandler(c) } return cfg.ErrorHandler(c, err) @@ -72,20 +68,12 @@ func New(config ...Config) fiber.Handler { // TokenFromContext returns the bearer token from the request context. // returns an empty string if the token does not exist -func TokenFromContext(c any) string { - switch ctx := c.(type) { - case context.Context: - if token, ok := ctx.Value(tokenKey).(string); ok { - return token - } - case fiber.Ctx: - if token, ok := ctx.Locals(tokenKey).(string); ok { - return token - } - default: - panic("unsupported context type, expected fiber.Ctx or context.Context") +func TokenFromContext(c fiber.Ctx) string { + token, ok := c.Locals(tokenKey).(string) + if !ok { + return "" } - return "" + return token } // MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go index 27c4e5a0..72c9d3c1 100644 --- a/middleware/keyauth/keyauth_test.go +++ b/middleware/keyauth/keyauth_test.go @@ -503,67 +503,33 @@ func Test_TokenFromContext_None(t *testing.T) { } func Test_TokenFromContext(t *testing.T) { - // Test that TokenFromContext returns the correct token - t.Run("fiber.Ctx", func(t *testing.T) { - app := fiber.New() - app.Use(New(Config{ - KeyLookup: "header:Authorization", - AuthScheme: "Basic", - Validator: func(_ fiber.Ctx, key string) (bool, error) { - if key == CorrectKey { - return true, nil - } - return false, ErrMissingOrMalformedAPIKey - }, - })) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString(TokenFromContext(c)) - }) - - req := httptest.NewRequest(fiber.MethodGet, "/", nil) - req.Header.Add("Authorization", "Basic "+CorrectKey) - res, err := app.Test(req) - require.NoError(t, err) - - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, CorrectKey, string(body)) + app := fiber.New() + // Wire up keyauth middleware to set TokenFromContext now + app.Use(New(Config{ + KeyLookup: "header:Authorization", + AuthScheme: "Basic", + Validator: func(_ fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + })) + // Define a test handler that checks TokenFromContext + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(TokenFromContext(c)) }) - t.Run("context.Context", func(t *testing.T) { - app := fiber.New() - app.Use(New(Config{ - KeyLookup: "header:Authorization", - AuthScheme: "Basic", - Validator: func(_ fiber.Ctx, key string) (bool, error) { - if key == CorrectKey { - return true, nil - } - return false, ErrMissingOrMalformedAPIKey - }, - })) - // Verify that TokenFromContext works with context.Context - app.Get("/", func(c fiber.Ctx) error { - ctx := c.Context() - token := TokenFromContext(ctx) - return c.SendString(token) - }) + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Add("Authorization", "Basic "+CorrectKey) + // Send + res, err := app.Test(req) + require.NoError(t, err) - req := httptest.NewRequest(fiber.MethodGet, "/", nil) - req.Header.Add("Authorization", "Basic "+CorrectKey) - res, err := app.Test(req) - require.NoError(t, err) - - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, CorrectKey, string(body)) - }) - - t.Run("invalid context type", func(t *testing.T) { - require.Panics(t, func() { - _ = TokenFromContext("invalid") - }) - }) + // Read the response body into a string + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, CorrectKey, string(body)) } func Test_AuthSchemeToken(t *testing.T) {