diff --git a/app.go b/app.go index 43ccaf9a..ae16a618 100644 --- a/app.go +++ b/app.go @@ -106,6 +106,8 @@ type App struct { tlsHandler *TLSHandler // Mount fields mountFields *mountFields + // state management + state *State // Route stack divided by HTTP methods stack [][]*Route // Route stack divided by HTTP methods and route prefixes @@ -527,6 +529,9 @@ func New(config ...Config) *App { // Define mountFields app.mountFields = newMountFields(app) + // Define state + app.state = newState() + // Override config if provided if len(config) > 0 { app.config = config[0] @@ -964,6 +969,11 @@ func (app *App) Hooks() *Hooks { return app.hooks } +// State returns the state struct to store global data in order to share it between handlers. +func (app *App) State() *State { + return app.state +} + var ErrTestGotEmptyResponse = errors.New("test: got empty response") // TestConfig is a struct holding Test settings diff --git a/app_test.go b/app_test.go index 84606e6b..b5d5ed46 100644 --- a/app_test.go +++ b/app_test.go @@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) { require.Equal(t, "/simple-route", sRoute2.Path) } +func Test_App_State(t *testing.T) { + t.Parallel() + app := New() + + app.State().Set("key", "value") + str, ok := app.State().GetString("key") + require.True(t, ok) + require.Equal(t, "value", str) +} + // go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4 func Benchmark_Communication_Flow(b *testing.B) { app := New() diff --git a/docs/api/constants.md b/docs/api/constants.md index 53bbb25c..70366de1 100644 --- a/docs/api/constants.md +++ b/docs/api/constants.md @@ -2,7 +2,7 @@ id: constants title: 📋 Constants description: Some constants for Fiber. -sidebar_position: 8 +sidebar_position: 9 --- ### HTTP methods were copied from net/http diff --git a/docs/api/state.md b/docs/api/state.md new file mode 100644 index 00000000..b22b9675 --- /dev/null +++ b/docs/api/state.md @@ -0,0 +1,640 @@ +# State Management + +This document details the state management functionality provided by Fiber, a thread-safe global key–value store used to store application dependencies and runtime data. The implementation is based on Go's `sync.Map`, ensuring concurrency safety. + +Below is the detailed description of all public methods and usage examples. + +## State Type + +`State` is a key–value store built on top of `sync.Map`. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as thread–safe access to runtime data. + +### Definition + +```go +// State is a key–value store for Fiber's app, used as a global storage for the app's dependencies. +// It is a thread–safe implementation of a map[string]any, using sync.Map. +type State struct { + dependencies sync.Map +} +``` + +## Methods on State + +### Set + +Set adds or updates a key–value pair in the State. + +```go +// Set adds or updates a key–value pair in the State. +func (s *State) Set(key string, value any) +``` + +**Usage Example:** + +```go +app.State().Set("appName", "My Fiber App") +``` + +### Get + +Get retrieves a value from the State. + +```go title="Signature" +func (s *State) Get(key string) (any, bool) +``` + +**Usage Example:** + +```go +value, ok := app.State().Get("appName") +if ok { + fmt.Println("App Name:", value) +} +``` + +### MustGet + +MustGet retrieves a value from the State and panics if the key is not found. + +```go title="Signature" +func (s *State) MustGet(key string) any +``` + +**Usage Example:** + +```go +appName := app.State().MustGet("appName") +fmt.Println("App Name:", appName) +``` + +### Has + +Has checks if a key exists in the State. + +```go title="Signature"s +func (s *State) Has(key string) bool +``` + +**Usage Example:** + +```go +if app.State().Has("appName") { + fmt.Println("App Name is set.") +} +``` + +### Delete + +Delete removes a key–value pair from the State. + +```go title="Signature" +func (s *State) Delete(key string) +``` + +**Usage Example:** + +```go +app.State().Delete("obsoleteKey") +``` + +### Reset + +Reset removes all keys from the State. + +```go title="Signature" +func (s *State) Reset() +``` + +**Usage Example:** + +```go +app.State().Reset() +``` + +### Keys + +Keys returns a slice containing all keys present in the State. + +```go title="Signature" +func (s *State) Keys() []string +``` + +**Usage Example:** + +```go +keys := app.State().Keys() +fmt.Println("State Keys:", keys) +``` + +### Len + +Len returns the number of keys in the State. + +```go +// Len returns the number of keys in the State. +func (s *State) Len() int +``` + +**Usage Example:** + +```go +fmt.Printf("Total State Entries: %d\n", app.State().Len()) +``` + +### GetString + +GetString retrieves a string value from the State. It returns the string and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetString(key string) (string, bool) +``` + +**Usage Example:** + +```go +if appName, ok := app.State().GetString("appName"); ok { + fmt.Println("App Name:", appName) +} +``` + +### GetInt + +GetInt retrieves an integer value from the State. It returns the int and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetInt(key string) (int, bool) +``` + +**Usage Example:** + +```go +if count, ok := app.State().GetInt("userCount"); ok { + fmt.Printf("User Count: %d\n", count) +} +``` + +### GetBool + +GetBool retrieves a boolean value from the State. It returns the bool and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetBool(key string) (value, bool) +``` + +**Usage Example:** + +```go +if debug, ok := app.State().GetBool("debugMode"); ok { + fmt.Printf("Debug Mode: %v\n", debug) +} +``` + +### GetFloat64 + +GetFloat64 retrieves a float64 value from the State. It returns the float64 and a boolean indicating a successful type assertion. + +```go title="Signature" +func (s *State) GetFloat64(key string) (float64, bool) +``` + +**Usage Example:** + +```go title="Signature" +if ratio, ok := app.State().GetFloat64("scalingFactor"); ok { + fmt.Printf("Scaling Factor: %f\n", ratio) +} +``` + +### GetUint + +GetUint retrieves a `uint` value from the State. + +```go title="Signature" +func (s *State) GetUint(key string) (uint, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint("maxConnections"); ok { + fmt.Printf("Max Connections: %d\n", val) +} +``` + +### GetInt8 + +GetInt8 retrieves an `int8` value from the State. + +```go title="Signature" +func (s *State) GetInt8(key string) (int8, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt8("threshold"); ok { + fmt.Printf("Threshold: %d\n", val) +} +``` + +### GetInt16 + +GetInt16 retrieves an `int16` value from the State. + +```go title="Signature" +func (s *State) GetInt16(key string) (int16, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt16("minValue"); ok { + fmt.Printf("Minimum Value: %d\n", val) +} +``` + +### GetInt32 + +GetInt32 retrieves an `int32` value from the State. + +```go title="Signature" +func (s *State) GetInt32(key string) (int32, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt32("portNumber"); ok { + fmt.Printf("Port Number: %d\n", val) +} +``` + +### GetInt64 + +GetInt64 retrieves an `int64` value from the State. + +```go title="Signature" +func (s *State) GetInt64(key string) (int64, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetInt64("fileSize"); ok { + fmt.Printf("File Size: %d\n", val) +} +``` + +### GetUint8 + +GetUint8 retrieves a `uint8` value from the State. + +```go title="Signature" +func (s *State) GetUint8(key string) (uint8, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint8("byteValue"); ok { + fmt.Printf("Byte Value: %d\n", val) +} +``` + +### GetUint16 + +GetUint16 retrieves a `uint16` value from the State. + +```go title="Signature" +func (s *State) GetUint16(key string) (uint16, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint16("limit"); ok { + fmt.Printf("Limit: %d\n", val) +} +``` + +### GetUint32 + +GetUint32 retrieves a `uint32` value from the State. + +```go title="Signature" +func (s *State) GetUint32(key string) (uint32, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint32("timeout"); ok { + fmt.Printf("Timeout: %d\n", val) +} +``` + +### GetUint64 + +GetUint64 retrieves a `uint64` value from the State. + +```go title="Signature" +func (s *State) GetUint64(key string) (uint64, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUint64("maxSize"); ok { + fmt.Printf("Max Size: %d\n", val) +} +``` + +### GetUintptr + +GetUintptr retrieves a `uintptr` value from the State. + +```go title="Signature" +func (s *State) GetUintptr(key string) (uintptr, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetUintptr("pointerValue"); ok { + fmt.Printf("Pointer Value: %d\n", val) +} +``` + +### GetFloat32 + +GetFloat32 retrieves a `float32` value from the State. + +```go title="Signature" +func (s *State) GetFloat32(key string) (float32, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetFloat32("scalingFactor32"); ok { + fmt.Printf("Scaling Factor (float32): %f\n", val) +} +``` + +### GetComplex64 + +GetComplex64 retrieves a `complex64` value from the State. + +```go title="Signature" +func (s *State) GetComplex64(key string) (complex64, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetComplex64("complexVal"); ok { + fmt.Printf("Complex Value (complex64): %v\n", val) +} +``` + +### GetComplex128 + +GetComplex128 retrieves a `complex128` value from the State. + +```go title="Signature" +func (s *State) GetComplex128(key string) (complex128, bool) +``` + +**Usage Example:** + +```go +if val, ok := app.State().GetComplex128("complexVal128"); ok { + fmt.Printf("Complex Value (complex128): %v\n", val) +} +``` + +## Generic Functions + +Fiber provides generic functions to retrieve state values with type safety and fallback options. + +### GetState + +GetState retrieves a value from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful. + +```go title="Signature" +func GetState[T any](s *State, key string) (T, bool) +``` + +**Usage Example:** + +```go +// Retrieve an integer value safely. +userCount, ok := GetState[int](app.State(), "userCount") +if ok { + fmt.Printf("User Count: %d\n", userCount) +} +``` + +### MustGetState + +MustGetState retrieves a value from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails. + +```go title="Signature" +func MustGetState[T any](s *State, key string) T +``` + +**Usage Example:** + +```go +// Retrieve the value or panic if it is not present. +config := MustGetState[string](app.State(), "configFile") +fmt.Println("Config File:", config) +``` + +### GetStateWithDefault + +GetStateWithDefault retrieves a value from the State, casting it to the desired type. If the key is not present, it returns the provided default value. + +```go title="Signature" +func GetStateWithDefault[T any](s *State, key string, defaultVal T) T +``` + +**Usage Example:** + +```go +// Retrieve a value with a default fallback. +requestCount := GetStateWithDefault[int](app.State(), "requestCount", 0) +fmt.Printf("Request Count: %d\n", requestCount) +``` + +## Comprehensive Examples + +### Example: Request Counter + +This example demonstrates how to track the number of requests using the State. + +```go +package main + +import ( + "fmt" + + "github.com/gofiber/fiber/v3" +) + +func main() { + app := fiber.New() + + // Initialize state with a counter. + app.State().Set("requestCount", 0) + + // Middleware: Increase counter for every request. + app.Use(func(c fiber.Ctx) error { + count, _ := c.App().State().GetInt("requestCount") + app.State().Set("requestCount", count+1) + return c.Next() + }) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello World!") + }) + + app.Get("/stats", func(c fiber.Ctx) error { + count, _ := c.App().State().Get("requestCount") + return c.SendString(fmt.Sprintf("Total requests: %d", count)) + }) + + app.Listen(":3000") +} +``` + +### Example: Environment–Specific Configuration + +This example shows how to configure different settings based on the environment. + +```go +package main + +import ( + "os" + + "github.com/gofiber/fiber/v3" +) + +func main() { + app := fiber.New() + + // Determine environment. + environment := os.Getenv("ENV") + if environment == "" { + environment = "development" + } + app.State().Set("environment", environment) + + // Set environment-specific configurations. + if environment == "development" { + app.State().Set("apiUrl", "http://localhost:8080/api") + app.State().Set("debug", true) + } else { + app.State().Set("apiUrl", "https://api.production.com") + app.State().Set("debug", false) + } + + app.Get("/config", func(c fiber.Ctx) error { + config := map[string]any{ + "environment": environment, + "apiUrl": fiber.GetStateWithDefault(c.App().State(), "apiUrl", ""), + "debug": fiber.GetStateWithDefault(c.App().State(), "debug", false), + } + return c.JSON(config) + }) + + app.Listen(":3000") +} +``` + +### Example: Dependency Injection with State Management + +This example demonstrates how to use the State for dependency injection in a Fiber application. + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/gofiber/fiber/v3" + "github.com/redis/go-redis/v9" +) + +type User struct { + ID int `query:"id"` + Name string `query:"name"` + Email string `query:"email"` +} + +func main() { + app := fiber.New() + ctx := context.Background() + + // Initialize Redis client. + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", + DB: 0, + }) + + // Check the Redis connection. + if err := rdb.Ping(ctx).Err(); err != nil { + log.Fatalf("Could not connect to Redis: %v", err) + } + + // Inject the Redis client into Fiber's State for dependency injection. + app.State().Set("redis", rdb) + + app.Get("/user/create", func(c fiber.Ctx) error { + var user User + if err := c.Bind().Query(&user); err != nil { + return c.Status(fiber.StatusBadRequest).SendString(err.Error()) + } + + // Save the user to the database. + rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis") + if !ok { + return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found") + } + + // Save the user to the database. + key := fmt.Sprintf("user:%d", user.ID) + err := rdb.HSet(ctx, key, "name", user.Name, "email", user.Email).Err() + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.JSON(user) + }) + + app.Get("/user/:id", func(c fiber.Ctx) error { + id := c.Params("id") + + rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis") + if !ok { + return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found") + } + + key := fmt.Sprintf("user:%s", id) + user, err := rdb.HGetAll(ctx, key).Result() + if err == redis.Nil { + return c.Status(fiber.StatusNotFound).SendString("User not found") + } else if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.JSON(user) + }) + + app.Listen(":3000") +} +``` diff --git a/docs/whats_new.md b/docs/whats_new.md index 19f261ae..4bf28927 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -59,6 +59,7 @@ We have made several changes to the Fiber app, including: - **RegisterCustomBinder**: Allows for the registration of custom binders. - **RegisterCustomConstraint**: Allows for the registration of custom constraints. - **NewCtxFunc**: Introduces a new context function. +- **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details. ### Removed Methods diff --git a/state.go b/state.go new file mode 100644 index 00000000..43687603 --- /dev/null +++ b/state.go @@ -0,0 +1,322 @@ +package fiber + +import ( + "sync" +) + +// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies. +// It's a thread-safe implementation of a map[string]any, using sync.Map. +type State struct { + dependencies sync.Map +} + +// NewState creates a new instance of State. +func newState() *State { + return &State{ + dependencies: sync.Map{}, + } +} + +// Set sets a key-value pair in the State. +func (s *State) Set(key string, value any) { + s.dependencies.Store(key, value) +} + +// Get retrieves a value from the State. +func (s *State) Get(key string) (any, bool) { + return s.dependencies.Load(key) +} + +// MustGet retrieves a value from the State and panics if the key is not found. +func (s *State) MustGet(key string) any { + if dep, ok := s.Get(key); ok { + return dep + } + + panic("state: dependency not found!") +} + +// Has checks if a key is present in the State. +// It returns a boolean indicating if the key is present. +func (s *State) Has(key string) bool { + _, ok := s.Get(key) + return ok +} + +// Delete removes a key-value pair from the State. +func (s *State) Delete(key string) { + s.dependencies.Delete(key) +} + +// Reset resets the State by removing all keys. +func (s *State) Reset() { + s.dependencies.Clear() +} + +// Keys returns a slice containing all keys present in the State. +func (s *State) Keys() []string { + keys := make([]string, 0) + s.dependencies.Range(func(key, _ any) bool { + keyStr, ok := key.(string) + if !ok { + return false + } + + keys = append(keys, keyStr) + return true + }) + + return keys +} + +// Len returns the number of keys in the State. +func (s *State) Len() int { + length := 0 + s.dependencies.Range(func(_, _ any) bool { + length++ + return true + }) + + return length +} + +// GetState retrieves a value from the State and casts it to the desired type. +// It returns the casted value and a boolean indicating if the cast was successful. +func GetState[T any](s *State, key string) (T, bool) { + dep, ok := s.Get(key) + + if ok { + depT, okCast := dep.(T) + return depT, okCast + } + + var zeroVal T + return zeroVal, false +} + +// MustGetState retrieves a value from the State and casts it to the desired type. +// It panics if the key is not found or if the type assertion fails. +func MustGetState[T any](s *State, key string) T { + dep, ok := GetState[T](s, key) + if !ok { + panic("state: dependency not found!") + } + + return dep +} + +// GetStateWithDefault retrieves a value from the State, +// casting it to the desired type. If the key is not present, +// it returns the provided default value. +func GetStateWithDefault[T any](s *State, key string, defaultVal T) T { + dep, ok := GetState[T](s, key) + if !ok { + return defaultVal + } + + return dep +} + +// GetString retrieves a string value from the State. +// It returns the string and a boolean indicating successful type assertion. +func (s *State) GetString(key string) (string, bool) { + dep, ok := s.Get(key) + if ok { + depString, okCast := dep.(string) + return depString, okCast + } + + return "", false +} + +// GetInt retrieves an integer value from the State. +// It returns the int and a boolean indicating successful type assertion. +func (s *State) GetInt(key string) (int, bool) { + dep, ok := s.Get(key) + if ok { + depInt, okCast := dep.(int) + return depInt, okCast + } + + return 0, false +} + +// GetBool retrieves a boolean value from the State. +// It returns the bool and a boolean indicating successful type assertion. +func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here + dep, ok := s.Get(key) + if ok { + depBool, okCast := dep.(bool) + return depBool, okCast + } + + return false, false +} + +// GetFloat64 retrieves a float64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetFloat64(key string) (float64, bool) { + dep, ok := s.Get(key) + if ok { + depFloat64, okCast := dep.(float64) + return depFloat64, okCast + } + + return 0, false +} + +// GetUint retrieves a uint value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint(key string) (uint, bool) { + dep, ok := s.Get(key) + if ok { + if depUint, okCast := dep.(uint); okCast { + return depUint, true + } + } + return 0, false +} + +// GetInt8 retrieves an int8 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt8(key string) (int8, bool) { + dep, ok := s.Get(key) + if ok { + if depInt8, okCast := dep.(int8); okCast { + return depInt8, true + } + } + return 0, false +} + +// GetInt16 retrieves an int16 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt16(key string) (int16, bool) { + dep, ok := s.Get(key) + if ok { + if depInt16, okCast := dep.(int16); okCast { + return depInt16, true + } + } + return 0, false +} + +// GetInt32 retrieves an int32 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt32(key string) (int32, bool) { + dep, ok := s.Get(key) + if ok { + if depInt32, okCast := dep.(int32); okCast { + return depInt32, true + } + } + return 0, false +} + +// GetInt64 retrieves an int64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetInt64(key string) (int64, bool) { + dep, ok := s.Get(key) + if ok { + if depInt64, okCast := dep.(int64); okCast { + return depInt64, true + } + } + return 0, false +} + +// GetUint8 retrieves a uint8 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint8(key string) (uint8, bool) { + dep, ok := s.Get(key) + if ok { + if depUint8, okCast := dep.(uint8); okCast { + return depUint8, true + } + } + return 0, false +} + +// GetUint16 retrieves a uint16 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint16(key string) (uint16, bool) { + dep, ok := s.Get(key) + if ok { + if depUint16, okCast := dep.(uint16); okCast { + return depUint16, true + } + } + return 0, false +} + +// GetUint32 retrieves a uint32 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint32(key string) (uint32, bool) { + dep, ok := s.Get(key) + if ok { + if depUint32, okCast := dep.(uint32); okCast { + return depUint32, true + } + } + return 0, false +} + +// GetUint64 retrieves a uint64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUint64(key string) (uint64, bool) { + dep, ok := s.Get(key) + if ok { + if depUint64, okCast := dep.(uint64); okCast { + return depUint64, true + } + } + return 0, false +} + +// GetUintptr retrieves a uintptr value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetUintptr(key string) (uintptr, bool) { + dep, ok := s.Get(key) + if ok { + if depUintptr, okCast := dep.(uintptr); okCast { + return depUintptr, true + } + } + return 0, false +} + +// GetFloat32 retrieves a float32 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetFloat32(key string) (float32, bool) { + dep, ok := s.Get(key) + if ok { + if depFloat32, okCast := dep.(float32); okCast { + return depFloat32, true + } + } + return 0, false +} + +// GetComplex64 retrieves a complex64 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetComplex64(key string) (complex64, bool) { + dep, ok := s.Get(key) + if ok { + if depComplex64, okCast := dep.(complex64); okCast { + return depComplex64, true + } + } + return 0, false +} + +// GetComplex128 retrieves a complex128 value from the State. +// It returns the float64 and a boolean indicating successful type assertion. +func (s *State) GetComplex128(key string) (complex128, bool) { + dep, ok := s.Get(key) + if ok { + if depComplex128, okCast := dep.(complex128); okCast { + return depComplex128, true + } + } + return 0, false +} diff --git a/state_test.go b/state_test.go new file mode 100644 index 00000000..e96ea0dd --- /dev/null +++ b/state_test.go @@ -0,0 +1,981 @@ +package fiber + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestState_SetAndGet_WithApp(t *testing.T) { + t.Parallel() + // Create app + app := New() + + // test setting and getting a value + app.State().Set("foo", "bar") + val, ok := app.State().Get("foo") + require.True(t, ok) + require.Equal(t, "bar", val) + + // test key not found + _, ok = app.State().Get("unknown") + require.False(t, ok) +} + +func TestState_SetAndGet(t *testing.T) { + t.Parallel() + st := newState() + + // test setting and getting a value + st.Set("foo", "bar") + val, ok := st.Get("foo") + require.True(t, ok) + require.Equal(t, "bar", val) + + // test key not found + _, ok = st.Get("unknown") + require.False(t, ok) +} + +func TestState_GetString(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("str", "hello") + s, ok := st.GetString("str") + require.True(t, ok) + require.Equal(t, "hello", s) + + // wrong type should return false + st.Set("num", 123) + s, ok = st.GetString("num") + require.False(t, ok) + require.Equal(t, "", s) + + // missing key should return false + s, ok = st.GetString("missing") + require.False(t, ok) + require.Equal(t, "", s) +} + +func TestState_GetInt(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("num", 456) + i, ok := st.GetInt("num") + require.True(t, ok) + require.Equal(t, 456, i) + + // wrong type should return zero value + st.Set("str", "abc") + i, ok = st.GetInt("str") + require.False(t, ok) + require.Equal(t, 0, i) + + // missing key should return zero value + i, ok = st.GetInt("missing") + require.False(t, ok) + require.Equal(t, 0, i) +} + +func TestState_GetBool(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + b, ok := st.GetBool("flag") + require.True(t, ok) + require.True(t, b) + + // wrong type + st.Set("num", 1) + b, ok = st.GetBool("num") + require.False(t, ok) + require.False(t, b) + + // missing key should return false + b, ok = st.GetBool("missing") + require.False(t, ok) + require.False(t, b) +} + +func TestState_GetFloat64(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("pi", 3.14) + f, ok := st.GetFloat64("pi") + require.True(t, ok) + require.InDelta(t, 3.14, f, 0.0001) + + // wrong type should return zero value + st.Set("int", 10) + f, ok = st.GetFloat64("int") + require.False(t, ok) + require.InDelta(t, 0.0, f, 0.0001) + + // missing key should return zero value + f, ok = st.GetFloat64("missing") + require.False(t, ok) + require.InDelta(t, 0.0, f, 0.0001) +} + +func TestState_GetUint(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint", uint(100)) + u, ok := st.GetUint("uint") + require.True(t, ok) + require.Equal(t, uint(100), u) + + st.Set("wrong", "not uint") + u, ok = st.GetUint("wrong") + require.False(t, ok) + require.Equal(t, uint(0), u) + + u, ok = st.GetUint("missing") + require.False(t, ok) + require.Equal(t, uint(0), u) +} + +func TestState_GetInt8(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int8", int8(10)) + i, ok := st.GetInt8("int8") + require.True(t, ok) + require.Equal(t, int8(10), i) + + st.Set("wrong", "not int8") + i, ok = st.GetInt8("wrong") + require.False(t, ok) + require.Equal(t, int8(0), i) + + i, ok = st.GetInt8("missing") + require.False(t, ok) + require.Equal(t, int8(0), i) +} + +func TestState_GetInt16(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int16", int16(200)) + i, ok := st.GetInt16("int16") + require.True(t, ok) + require.Equal(t, int16(200), i) + + st.Set("wrong", "not int16") + i, ok = st.GetInt16("wrong") + require.False(t, ok) + require.Equal(t, int16(0), i) + + i, ok = st.GetInt16("missing") + require.False(t, ok) + require.Equal(t, int16(0), i) +} + +func TestState_GetInt32(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int32", int32(3000)) + i, ok := st.GetInt32("int32") + require.True(t, ok) + require.Equal(t, int32(3000), i) + + st.Set("wrong", "not int32") + i, ok = st.GetInt32("wrong") + require.False(t, ok) + require.Equal(t, int32(0), i) + + i, ok = st.GetInt32("missing") + require.False(t, ok) + require.Equal(t, int32(0), i) +} + +func TestState_GetInt64(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("int64", int64(4000)) + i, ok := st.GetInt64("int64") + require.True(t, ok) + require.Equal(t, int64(4000), i) + + st.Set("wrong", "not int64") + i, ok = st.GetInt64("wrong") + require.False(t, ok) + require.Equal(t, int64(0), i) + + i, ok = st.GetInt64("missing") + require.False(t, ok) + require.Equal(t, int64(0), i) +} + +func TestState_GetUint8(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint8", uint8(20)) + u, ok := st.GetUint8("uint8") + require.True(t, ok) + require.Equal(t, uint8(20), u) + + st.Set("wrong", "not uint8") + u, ok = st.GetUint8("wrong") + require.False(t, ok) + require.Equal(t, uint8(0), u) + + u, ok = st.GetUint8("missing") + require.False(t, ok) + require.Equal(t, uint8(0), u) +} + +func TestState_GetUint16(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint16", uint16(300)) + u, ok := st.GetUint16("uint16") + require.True(t, ok) + require.Equal(t, uint16(300), u) + + st.Set("wrong", "not uint16") + u, ok = st.GetUint16("wrong") + require.False(t, ok) + require.Equal(t, uint16(0), u) + + u, ok = st.GetUint16("missing") + require.False(t, ok) + require.Equal(t, uint16(0), u) +} + +func TestState_GetUint32(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint32", uint32(400000)) + u, ok := st.GetUint32("uint32") + require.True(t, ok) + require.Equal(t, uint32(400000), u) + + st.Set("wrong", "not uint32") + u, ok = st.GetUint32("wrong") + require.False(t, ok) + require.Equal(t, uint32(0), u) + + u, ok = st.GetUint32("missing") + require.False(t, ok) + require.Equal(t, uint32(0), u) +} + +func TestState_GetUint64(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("uint64", uint64(5000000)) + u, ok := st.GetUint64("uint64") + require.True(t, ok) + require.Equal(t, uint64(5000000), u) + + st.Set("wrong", "not uint64") + u, ok = st.GetUint64("wrong") + require.False(t, ok) + require.Equal(t, uint64(0), u) + + u, ok = st.GetUint64("missing") + require.False(t, ok) + require.Equal(t, uint64(0), u) +} + +func TestState_GetUintptr(t *testing.T) { + t.Parallel() + st := newState() + + var ptr uintptr = 12345 + st.Set("uintptr", ptr) + u, ok := st.GetUintptr("uintptr") + require.True(t, ok) + require.Equal(t, ptr, u) + + st.Set("wrong", "not uintptr") + u, ok = st.GetUintptr("wrong") + require.False(t, ok) + require.Equal(t, uintptr(0), u) + + u, ok = st.GetUintptr("missing") + require.False(t, ok) + require.Equal(t, uintptr(0), u) +} + +func TestState_GetFloat32(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("float32", float32(3.14)) + f, ok := st.GetFloat32("float32") + require.True(t, ok) + require.InDelta(t, float32(3.14), f, 0.0001) + + st.Set("wrong", "not float32") + f, ok = st.GetFloat32("wrong") + require.False(t, ok) + require.InDelta(t, float32(0), f, 0.0001) + + f, ok = st.GetFloat32("missing") + require.False(t, ok) + require.InDelta(t, float32(0), f, 0.0001) +} + +func TestState_GetComplex64(t *testing.T) { + t.Parallel() + st := newState() + + var c complex64 = complex(2, 3) + st.Set("complex64", c) + cRes, ok := st.GetComplex64("complex64") + require.True(t, ok) + require.Equal(t, c, cRes) + + st.Set("wrong", "not complex64") + cRes, ok = st.GetComplex64("wrong") + require.False(t, ok) + require.Equal(t, complex64(0), cRes) + + cRes, ok = st.GetComplex64("missing") + require.False(t, ok) + require.Equal(t, complex64(0), cRes) +} + +func TestState_GetComplex128(t *testing.T) { + t.Parallel() + st := newState() + + c := complex(4, 5) + st.Set("complex128", c) + cRes, ok := st.GetComplex128("complex128") + require.True(t, ok) + require.Equal(t, c, cRes) + + st.Set("wrong", "not complex128") + cRes, ok = st.GetComplex128("wrong") + require.False(t, ok) + require.Equal(t, complex128(0), cRes) + + cRes, ok = st.GetComplex128("missing") + require.False(t, ok) + require.Equal(t, complex128(0), cRes) +} + +func TestState_MustGet(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("exists", "value") + val := st.MustGet("exists") + require.Equal(t, "value", val) + + // must-get on missing key should panic + require.Panics(t, func() { + _ = st.MustGet("missing") + }) +} + +func TestState_Has(t *testing.T) { + t.Parallel() + + st := newState() + + st.Set("key", "value") + require.True(t, st.Has("key")) +} + +func TestState_Delete(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("key", "value") + st.Delete("key") + _, ok := st.Get("key") + require.False(t, ok) +} + +func TestState_Reset(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("a", 1) + st.Set("b", 2) + st.Reset() + require.Equal(t, 0, st.Len()) + require.Empty(t, st.Keys()) +} + +func TestState_Keys(t *testing.T) { + t.Parallel() + st := newState() + + keys := []string{"one", "two", "three"} + for _, k := range keys { + st.Set(k, k) + } + + returnedKeys := st.Keys() + require.ElementsMatch(t, keys, returnedKeys) +} + +func TestState_Len(t *testing.T) { + t.Parallel() + st := newState() + + require.Equal(t, 0, st.Len()) + + st.Set("a", "a") + require.Equal(t, 1, st.Len()) + + st.Set("b", "b") + require.Equal(t, 2, st.Len()) + + st.Delete("a") + require.Equal(t, 1, st.Len()) +} + +type testCase[T any] struct { //nolint:govet // It does not really matter for test + name string + key string + value any + expected T + ok bool +} + +func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) { + t.Helper() + + st := newState() + for _, tc := range tests { + st.Set(tc.key, tc.value) + got, ok := getter(st, tc.key) + require.Equal(t, tc.ok, ok, tc.name) + require.Equal(t, tc.expected, got, tc.name) + } +} + +func TestState_GetGeneric(t *testing.T) { + t.Parallel() + + runGenericTest[int](t, GetState[int], []testCase[int]{ + {"int correct conversion", "num", 42, 42, true}, + {"int wrong conversion from string", "str", "abc", 0, false}, + }) + + runGenericTest[string](t, GetState[string], []testCase[string]{ + {"string correct conversion", "strVal", "hello", "hello", true}, + {"string wrong conversion from int", "intVal", 100, "", false}, + }) + + runGenericTest[bool](t, GetState[bool], []testCase[bool]{ + {"bool correct conversion", "flag", true, true, true}, + {"bool wrong conversion from int", "intFlag", 1, false, false}, + }) + + runGenericTest[float64](t, GetState[float64], []testCase[float64]{ + {"float64 correct conversion", "pi", 3.14, 3.14, true}, + {"float64 wrong conversion from int", "intVal", 10, 0.0, false}, + }) +} + +func Test_MustGetStateGeneric(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + flag := MustGetState[bool](st, "flag") + require.True(t, flag) + + // mismatched type should panic + require.Panics(t, func() { + _ = MustGetState[string](st, "flag") + }) + + // missing key should also panic + require.Panics(t, func() { + _ = MustGetState[string](st, "missing") + }) +} + +func Test_GetStateWithDefault(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + flag := GetStateWithDefault(st, "flag", false) + require.True(t, flag) + + // mismatched type should return the default value + str := GetStateWithDefault(st, "flag", "default") + require.Equal(t, "default", str) + + // missing key should return the default value + flag = GetStateWithDefault(st, "missing", false) + require.False(t, flag) +} + +func BenchmarkState_Set(b *testing.B) { + b.ReportAllocs() + + st := newState() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } +} + +func BenchmarkState_Get(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.Get(key) + } +} + +func BenchmarkState_GetString(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, strconv.Itoa(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetString(key) + } +} + +func BenchmarkState_GetInt(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt(key) + } +} + +func BenchmarkState_GetBool(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i%2 == 0) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetBool(key) + } +} + +func BenchmarkState_GetFloat64(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetFloat64(key) + } +} + +func BenchmarkState_MustGet(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.MustGet(key) + } +} + +func BenchmarkState_GetStateGeneric(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + GetState[int](st, key) + } +} + +func BenchmarkState_MustGetStateGeneric(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + MustGetState[int](st, key) + } +} + +func BenchmarkState_GetStateWithDefault(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + GetStateWithDefault[int](st, key, 0) + } +} + +func BenchmarkState_Has(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st.Has("key" + strconv.Itoa(i%n)) + } +} + +func BenchmarkState_Delete(b *testing.B) { + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st := newState() + st.Set("a", 1) + st.Delete("a") + } +} + +func BenchmarkState_Reset(b *testing.B) { + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st := newState() + // add a fixed number of keys before clearing + for j := 0; j < 100; j++ { + st.Set("key"+strconv.Itoa(j), j) + } + st.Reset() + } +} + +func BenchmarkState_Keys(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = st.Keys() + } +} + +func BenchmarkState_Len(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = st.Len() + } +} + +func BenchmarkState_GetUint(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint(key) + } +} + +func BenchmarkState_GetInt8(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int8 values (using modulo to stay in range). + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int8(i%128)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt8(key) + } +} + +func BenchmarkState_GetInt16(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int16 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int16(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt16(key) + } +} + +func BenchmarkState_GetInt32(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int32 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int32(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt32(key) + } +} + +func BenchmarkState_GetInt64(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with int64 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, int64(i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt64(key) + } +} + +func BenchmarkState_GetUint8(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint8 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint8(i%256)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint8(key) + } +} + +func BenchmarkState_GetUint16(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint16 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint16(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint16(key) + } +} + +func BenchmarkState_GetUint32(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint32 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint32(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint32(key) + } +} + +func BenchmarkState_GetUint64(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uint64 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uint64(i)) //nolint:gosec // This is a test + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUint64(key) + } +} + +func BenchmarkState_GetUintptr(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with uintptr values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, uintptr(i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetUintptr(key) + } +} + +func BenchmarkState_GetFloat32(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with float32 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, float32(i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetFloat32(key) + } +} + +func BenchmarkState_GetComplex64(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with complex64 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + // Create a complex64 value with both real and imaginary parts. + st.Set(key, complex(float32(i), float32(i))) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetComplex64(key) + } +} + +func BenchmarkState_GetComplex128(b *testing.B) { + b.ReportAllocs() + st := newState() + n := 1000 + // Pre-populate the state with complex128 values. + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + // Create a complex128 value with both real and imaginary parts. + st.Set(key, complex(float64(i), float64(i))) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetComplex128(key) + } +}