mirror of https://github.com/gofiber/fiber.git
✨ feat: Add support for application state management (#3360)
* ✨ feat: add support for application state management
* increase test coverage
* fix linter
* Fix typo
* add GetStateWithDefault helper
* add docs
* update what's new
* add has method
* fix linter
* update
* Add missing helpers for golang built-in types
* Fix lint issues
* Fix unit-tests. Update documentation
* Fix docs, add missing benchmarks
* Fix tests file
* Update default example and test
* Apply suggestions from code review
---------
Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
Co-authored-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>
Co-authored-by: RW <rene@gofiber.io>
pull/3382/merge
parent
75281bd874
commit
d19e993597
10
app.go
10
app.go
|
@ -106,6 +106,8 @@ type App struct {
|
|||
tlsHandler *TLSHandler
|
||||
// Mount fields
|
||||
mountFields *mountFields
|
||||
// state management
|
||||
state *State
|
||||
// Route stack divided by HTTP methods
|
||||
stack [][]*Route
|
||||
// Route stack divided by HTTP methods and route prefixes
|
||||
|
@ -527,6 +529,9 @@ func New(config ...Config) *App {
|
|||
// Define mountFields
|
||||
app.mountFields = newMountFields(app)
|
||||
|
||||
// Define state
|
||||
app.state = newState()
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
app.config = config[0]
|
||||
|
@ -964,6 +969,11 @@ func (app *App) Hooks() *Hooks {
|
|||
return app.hooks
|
||||
}
|
||||
|
||||
// State returns the state struct to store global data in order to share it between handlers.
|
||||
func (app *App) State() *State {
|
||||
return app.state
|
||||
}
|
||||
|
||||
var ErrTestGotEmptyResponse = errors.New("test: got empty response")
|
||||
|
||||
// TestConfig is a struct holding Test settings
|
||||
|
|
10
app_test.go
10
app_test.go
|
@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) {
|
|||
require.Equal(t, "/simple-route", sRoute2.Path)
|
||||
}
|
||||
|
||||
func Test_App_State(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := New()
|
||||
|
||||
app.State().Set("key", "value")
|
||||
str, ok := app.State().GetString("key")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "value", str)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4
|
||||
func Benchmark_Communication_Flow(b *testing.B) {
|
||||
app := New()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
id: constants
|
||||
title: 📋 Constants
|
||||
description: Some constants for Fiber.
|
||||
sidebar_position: 8
|
||||
sidebar_position: 9
|
||||
---
|
||||
|
||||
### HTTP methods were copied from net/http
|
||||
|
|
|
@ -0,0 +1,640 @@
|
|||
# State Management
|
||||
|
||||
This document details the state management functionality provided by Fiber, a thread-safe global key–value store used to store application dependencies and runtime data. The implementation is based on Go's `sync.Map`, ensuring concurrency safety.
|
||||
|
||||
Below is the detailed description of all public methods and usage examples.
|
||||
|
||||
## State Type
|
||||
|
||||
`State` is a key–value store built on top of `sync.Map`. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as thread–safe access to runtime data.
|
||||
|
||||
### Definition
|
||||
|
||||
```go
|
||||
// State is a key–value store for Fiber's app, used as a global storage for the app's dependencies.
|
||||
// It is a thread–safe implementation of a map[string]any, using sync.Map.
|
||||
type State struct {
|
||||
dependencies sync.Map
|
||||
}
|
||||
```
|
||||
|
||||
## Methods on State
|
||||
|
||||
### Set
|
||||
|
||||
Set adds or updates a key–value pair in the State.
|
||||
|
||||
```go
|
||||
// Set adds or updates a key–value pair in the State.
|
||||
func (s *State) Set(key string, value any)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
app.State().Set("appName", "My Fiber App")
|
||||
```
|
||||
|
||||
### Get
|
||||
|
||||
Get retrieves a value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Get(key string) (any, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
value, ok := app.State().Get("appName")
|
||||
if ok {
|
||||
fmt.Println("App Name:", value)
|
||||
}
|
||||
```
|
||||
|
||||
### MustGet
|
||||
|
||||
MustGet retrieves a value from the State and panics if the key is not found.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) MustGet(key string) any
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
appName := app.State().MustGet("appName")
|
||||
fmt.Println("App Name:", appName)
|
||||
```
|
||||
|
||||
### Has
|
||||
|
||||
Has checks if a key exists in the State.
|
||||
|
||||
```go title="Signature"s
|
||||
func (s *State) Has(key string) bool
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if app.State().Has("appName") {
|
||||
fmt.Println("App Name is set.")
|
||||
}
|
||||
```
|
||||
|
||||
### Delete
|
||||
|
||||
Delete removes a key–value pair from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Delete(key string)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
app.State().Delete("obsoleteKey")
|
||||
```
|
||||
|
||||
### Reset
|
||||
|
||||
Reset removes all keys from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Reset()
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
app.State().Reset()
|
||||
```
|
||||
|
||||
### Keys
|
||||
|
||||
Keys returns a slice containing all keys present in the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Keys() []string
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
keys := app.State().Keys()
|
||||
fmt.Println("State Keys:", keys)
|
||||
```
|
||||
|
||||
### Len
|
||||
|
||||
Len returns the number of keys in the State.
|
||||
|
||||
```go
|
||||
// Len returns the number of keys in the State.
|
||||
func (s *State) Len() int
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
fmt.Printf("Total State Entries: %d\n", app.State().Len())
|
||||
```
|
||||
|
||||
### GetString
|
||||
|
||||
GetString retrieves a string value from the State. It returns the string and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetString(key string) (string, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if appName, ok := app.State().GetString("appName"); ok {
|
||||
fmt.Println("App Name:", appName)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt
|
||||
|
||||
GetInt retrieves an integer value from the State. It returns the int and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt(key string) (int, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if count, ok := app.State().GetInt("userCount"); ok {
|
||||
fmt.Printf("User Count: %d\n", count)
|
||||
}
|
||||
```
|
||||
|
||||
### GetBool
|
||||
|
||||
GetBool retrieves a boolean value from the State. It returns the bool and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetBool(key string) (value, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if debug, ok := app.State().GetBool("debugMode"); ok {
|
||||
fmt.Printf("Debug Mode: %v\n", debug)
|
||||
}
|
||||
```
|
||||
|
||||
### GetFloat64
|
||||
|
||||
GetFloat64 retrieves a float64 value from the State. It returns the float64 and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetFloat64(key string) (float64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go title="Signature"
|
||||
if ratio, ok := app.State().GetFloat64("scalingFactor"); ok {
|
||||
fmt.Printf("Scaling Factor: %f\n", ratio)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint
|
||||
|
||||
GetUint retrieves a `uint` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint(key string) (uint, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint("maxConnections"); ok {
|
||||
fmt.Printf("Max Connections: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt8
|
||||
|
||||
GetInt8 retrieves an `int8` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt8(key string) (int8, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt8("threshold"); ok {
|
||||
fmt.Printf("Threshold: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt16
|
||||
|
||||
GetInt16 retrieves an `int16` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt16(key string) (int16, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt16("minValue"); ok {
|
||||
fmt.Printf("Minimum Value: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt32
|
||||
|
||||
GetInt32 retrieves an `int32` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt32(key string) (int32, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt32("portNumber"); ok {
|
||||
fmt.Printf("Port Number: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt64
|
||||
|
||||
GetInt64 retrieves an `int64` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt64(key string) (int64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt64("fileSize"); ok {
|
||||
fmt.Printf("File Size: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint8
|
||||
|
||||
GetUint8 retrieves a `uint8` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint8(key string) (uint8, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint8("byteValue"); ok {
|
||||
fmt.Printf("Byte Value: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint16
|
||||
|
||||
GetUint16 retrieves a `uint16` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint16(key string) (uint16, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint16("limit"); ok {
|
||||
fmt.Printf("Limit: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint32
|
||||
|
||||
GetUint32 retrieves a `uint32` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint32(key string) (uint32, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint32("timeout"); ok {
|
||||
fmt.Printf("Timeout: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint64
|
||||
|
||||
GetUint64 retrieves a `uint64` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint64(key string) (uint64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint64("maxSize"); ok {
|
||||
fmt.Printf("Max Size: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUintptr
|
||||
|
||||
GetUintptr retrieves a `uintptr` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUintptr(key string) (uintptr, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUintptr("pointerValue"); ok {
|
||||
fmt.Printf("Pointer Value: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetFloat32
|
||||
|
||||
GetFloat32 retrieves a `float32` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetFloat32(key string) (float32, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetFloat32("scalingFactor32"); ok {
|
||||
fmt.Printf("Scaling Factor (float32): %f\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetComplex64
|
||||
|
||||
GetComplex64 retrieves a `complex64` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetComplex64(key string) (complex64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetComplex64("complexVal"); ok {
|
||||
fmt.Printf("Complex Value (complex64): %v\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetComplex128
|
||||
|
||||
GetComplex128 retrieves a `complex128` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetComplex128(key string) (complex128, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetComplex128("complexVal128"); ok {
|
||||
fmt.Printf("Complex Value (complex128): %v\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
## Generic Functions
|
||||
|
||||
Fiber provides generic functions to retrieve state values with type safety and fallback options.
|
||||
|
||||
### GetState
|
||||
|
||||
GetState retrieves a value from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful.
|
||||
|
||||
```go title="Signature"
|
||||
func GetState[T any](s *State, key string) (T, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
// Retrieve an integer value safely.
|
||||
userCount, ok := GetState[int](app.State(), "userCount")
|
||||
if ok {
|
||||
fmt.Printf("User Count: %d\n", userCount)
|
||||
}
|
||||
```
|
||||
|
||||
### MustGetState
|
||||
|
||||
MustGetState retrieves a value from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails.
|
||||
|
||||
```go title="Signature"
|
||||
func MustGetState[T any](s *State, key string) T
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
// Retrieve the value or panic if it is not present.
|
||||
config := MustGetState[string](app.State(), "configFile")
|
||||
fmt.Println("Config File:", config)
|
||||
```
|
||||
|
||||
### GetStateWithDefault
|
||||
|
||||
GetStateWithDefault retrieves a value from the State, casting it to the desired type. If the key is not present, it returns the provided default value.
|
||||
|
||||
```go title="Signature"
|
||||
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
// Retrieve a value with a default fallback.
|
||||
requestCount := GetStateWithDefault[int](app.State(), "requestCount", 0)
|
||||
fmt.Printf("Request Count: %d\n", requestCount)
|
||||
```
|
||||
|
||||
## Comprehensive Examples
|
||||
|
||||
### Example: Request Counter
|
||||
|
||||
This example demonstrates how to track the number of requests using the State.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
// Initialize state with a counter.
|
||||
app.State().Set("requestCount", 0)
|
||||
|
||||
// Middleware: Increase counter for every request.
|
||||
app.Use(func(c fiber.Ctx) error {
|
||||
count, _ := c.App().State().GetInt("requestCount")
|
||||
app.State().Set("requestCount", count+1)
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString("Hello World!")
|
||||
})
|
||||
|
||||
app.Get("/stats", func(c fiber.Ctx) error {
|
||||
count, _ := c.App().State().Get("requestCount")
|
||||
return c.SendString(fmt.Sprintf("Total requests: %d", count))
|
||||
})
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Environment–Specific Configuration
|
||||
|
||||
This example shows how to configure different settings based on the environment.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
// Determine environment.
|
||||
environment := os.Getenv("ENV")
|
||||
if environment == "" {
|
||||
environment = "development"
|
||||
}
|
||||
app.State().Set("environment", environment)
|
||||
|
||||
// Set environment-specific configurations.
|
||||
if environment == "development" {
|
||||
app.State().Set("apiUrl", "http://localhost:8080/api")
|
||||
app.State().Set("debug", true)
|
||||
} else {
|
||||
app.State().Set("apiUrl", "https://api.production.com")
|
||||
app.State().Set("debug", false)
|
||||
}
|
||||
|
||||
app.Get("/config", func(c fiber.Ctx) error {
|
||||
config := map[string]any{
|
||||
"environment": environment,
|
||||
"apiUrl": fiber.GetStateWithDefault(c.App().State(), "apiUrl", ""),
|
||||
"debug": fiber.GetStateWithDefault(c.App().State(), "debug", false),
|
||||
}
|
||||
return c.JSON(config)
|
||||
})
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Dependency Injection with State Management
|
||||
|
||||
This example demonstrates how to use the State for dependency injection in a Fiber application.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int `query:"id"`
|
||||
Name string `query:"name"`
|
||||
Email string `query:"email"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
ctx := context.Background()
|
||||
|
||||
// Initialize Redis client.
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "",
|
||||
DB: 0,
|
||||
})
|
||||
|
||||
// Check the Redis connection.
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
log.Fatalf("Could not connect to Redis: %v", err)
|
||||
}
|
||||
|
||||
// Inject the Redis client into Fiber's State for dependency injection.
|
||||
app.State().Set("redis", rdb)
|
||||
|
||||
app.Get("/user/create", func(c fiber.Ctx) error {
|
||||
var user User
|
||||
if err := c.Bind().Query(&user); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
|
||||
}
|
||||
|
||||
// Save the user to the database.
|
||||
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
|
||||
}
|
||||
|
||||
// Save the user to the database.
|
||||
key := fmt.Sprintf("user:%d", user.ID)
|
||||
err := rdb.HSet(ctx, key, "name", user.Name, "email", user.Email).Err()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
app.Get("/user/:id", func(c fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
|
||||
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("user:%s", id)
|
||||
user, err := rdb.HGetAll(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return c.Status(fiber.StatusNotFound).SendString("User not found")
|
||||
} else if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
```
|
|
@ -59,6 +59,7 @@ We have made several changes to the Fiber app, including:
|
|||
- **RegisterCustomBinder**: Allows for the registration of custom binders.
|
||||
- **RegisterCustomConstraint**: Allows for the registration of custom constraints.
|
||||
- **NewCtxFunc**: Introduces a new context function.
|
||||
- **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details.
|
||||
|
||||
### Removed Methods
|
||||
|
||||
|
|
|
@ -0,0 +1,322 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies.
|
||||
// It's a thread-safe implementation of a map[string]any, using sync.Map.
|
||||
type State struct {
|
||||
dependencies sync.Map
|
||||
}
|
||||
|
||||
// NewState creates a new instance of State.
|
||||
func newState() *State {
|
||||
return &State{
|
||||
dependencies: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a key-value pair in the State.
|
||||
func (s *State) Set(key string, value any) {
|
||||
s.dependencies.Store(key, value)
|
||||
}
|
||||
|
||||
// Get retrieves a value from the State.
|
||||
func (s *State) Get(key string) (any, bool) {
|
||||
return s.dependencies.Load(key)
|
||||
}
|
||||
|
||||
// MustGet retrieves a value from the State and panics if the key is not found.
|
||||
func (s *State) MustGet(key string) any {
|
||||
if dep, ok := s.Get(key); ok {
|
||||
return dep
|
||||
}
|
||||
|
||||
panic("state: dependency not found!")
|
||||
}
|
||||
|
||||
// Has checks if a key is present in the State.
|
||||
// It returns a boolean indicating if the key is present.
|
||||
func (s *State) Has(key string) bool {
|
||||
_, ok := s.Get(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Delete removes a key-value pair from the State.
|
||||
func (s *State) Delete(key string) {
|
||||
s.dependencies.Delete(key)
|
||||
}
|
||||
|
||||
// Reset resets the State by removing all keys.
|
||||
func (s *State) Reset() {
|
||||
s.dependencies.Clear()
|
||||
}
|
||||
|
||||
// Keys returns a slice containing all keys present in the State.
|
||||
func (s *State) Keys() []string {
|
||||
keys := make([]string, 0)
|
||||
s.dependencies.Range(func(key, _ any) bool {
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
keys = append(keys, keyStr)
|
||||
return true
|
||||
})
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// Len returns the number of keys in the State.
|
||||
func (s *State) Len() int {
|
||||
length := 0
|
||||
s.dependencies.Range(func(_, _ any) bool {
|
||||
length++
|
||||
return true
|
||||
})
|
||||
|
||||
return length
|
||||
}
|
||||
|
||||
// GetState retrieves a value from the State and casts it to the desired type.
|
||||
// It returns the casted value and a boolean indicating if the cast was successful.
|
||||
func GetState[T any](s *State, key string) (T, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
|
||||
if ok {
|
||||
depT, okCast := dep.(T)
|
||||
return depT, okCast
|
||||
}
|
||||
|
||||
var zeroVal T
|
||||
return zeroVal, false
|
||||
}
|
||||
|
||||
// MustGetState retrieves a value from the State and casts it to the desired type.
|
||||
// It panics if the key is not found or if the type assertion fails.
|
||||
func MustGetState[T any](s *State, key string) T {
|
||||
dep, ok := GetState[T](s, key)
|
||||
if !ok {
|
||||
panic("state: dependency not found!")
|
||||
}
|
||||
|
||||
return dep
|
||||
}
|
||||
|
||||
// GetStateWithDefault retrieves a value from the State,
|
||||
// casting it to the desired type. If the key is not present,
|
||||
// it returns the provided default value.
|
||||
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T {
|
||||
dep, ok := GetState[T](s, key)
|
||||
if !ok {
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
return dep
|
||||
}
|
||||
|
||||
// GetString retrieves a string value from the State.
|
||||
// It returns the string and a boolean indicating successful type assertion.
|
||||
func (s *State) GetString(key string) (string, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depString, okCast := dep.(string)
|
||||
return depString, okCast
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetInt retrieves an integer value from the State.
|
||||
// It returns the int and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt(key string) (int, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depInt, okCast := dep.(int)
|
||||
return depInt, okCast
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetBool retrieves a boolean value from the State.
|
||||
// It returns the bool and a boolean indicating successful type assertion.
|
||||
func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depBool, okCast := dep.(bool)
|
||||
return depBool, okCast
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetFloat64 retrieves a float64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetFloat64(key string) (float64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depFloat64, okCast := dep.(float64)
|
||||
return depFloat64, okCast
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint retrieves a uint value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint(key string) (uint, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint, okCast := dep.(uint); okCast {
|
||||
return depUint, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt8 retrieves an int8 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt8(key string) (int8, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt8, okCast := dep.(int8); okCast {
|
||||
return depInt8, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt16 retrieves an int16 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt16(key string) (int16, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt16, okCast := dep.(int16); okCast {
|
||||
return depInt16, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt32 retrieves an int32 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt32(key string) (int32, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt32, okCast := dep.(int32); okCast {
|
||||
return depInt32, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt64 retrieves an int64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt64(key string) (int64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt64, okCast := dep.(int64); okCast {
|
||||
return depInt64, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint8 retrieves a uint8 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint8(key string) (uint8, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint8, okCast := dep.(uint8); okCast {
|
||||
return depUint8, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint16 retrieves a uint16 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint16(key string) (uint16, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint16, okCast := dep.(uint16); okCast {
|
||||
return depUint16, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint32 retrieves a uint32 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint32(key string) (uint32, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint32, okCast := dep.(uint32); okCast {
|
||||
return depUint32, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint64 retrieves a uint64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint64(key string) (uint64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint64, okCast := dep.(uint64); okCast {
|
||||
return depUint64, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUintptr retrieves a uintptr value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUintptr(key string) (uintptr, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUintptr, okCast := dep.(uintptr); okCast {
|
||||
return depUintptr, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetFloat32 retrieves a float32 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetFloat32(key string) (float32, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depFloat32, okCast := dep.(float32); okCast {
|
||||
return depFloat32, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetComplex64 retrieves a complex64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetComplex64(key string) (complex64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depComplex64, okCast := dep.(complex64); okCast {
|
||||
return depComplex64, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetComplex128 retrieves a complex128 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetComplex128(key string) (complex128, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depComplex128, okCast := dep.(complex128); okCast {
|
||||
return depComplex128, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
|
@ -0,0 +1,981 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestState_SetAndGet_WithApp(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create app
|
||||
app := New()
|
||||
|
||||
// test setting and getting a value
|
||||
app.State().Set("foo", "bar")
|
||||
val, ok := app.State().Get("foo")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "bar", val)
|
||||
|
||||
// test key not found
|
||||
_, ok = app.State().Get("unknown")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_SetAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
// test setting and getting a value
|
||||
st.Set("foo", "bar")
|
||||
val, ok := st.Get("foo")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "bar", val)
|
||||
|
||||
// test key not found
|
||||
_, ok = st.Get("unknown")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_GetString(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("str", "hello")
|
||||
s, ok := st.GetString("str")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "hello", s)
|
||||
|
||||
// wrong type should return false
|
||||
st.Set("num", 123)
|
||||
s, ok = st.GetString("num")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, "", s)
|
||||
|
||||
// missing key should return false
|
||||
s, ok = st.GetString("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, "", s)
|
||||
}
|
||||
|
||||
func TestState_GetInt(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("num", 456)
|
||||
i, ok := st.GetInt("num")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 456, i)
|
||||
|
||||
// wrong type should return zero value
|
||||
st.Set("str", "abc")
|
||||
i, ok = st.GetInt("str")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0, i)
|
||||
|
||||
// missing key should return zero value
|
||||
i, ok = st.GetInt("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0, i)
|
||||
}
|
||||
|
||||
func TestState_GetBool(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
b, ok := st.GetBool("flag")
|
||||
require.True(t, ok)
|
||||
require.True(t, b)
|
||||
|
||||
// wrong type
|
||||
st.Set("num", 1)
|
||||
b, ok = st.GetBool("num")
|
||||
require.False(t, ok)
|
||||
require.False(t, b)
|
||||
|
||||
// missing key should return false
|
||||
b, ok = st.GetBool("missing")
|
||||
require.False(t, ok)
|
||||
require.False(t, b)
|
||||
}
|
||||
|
||||
func TestState_GetFloat64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("pi", 3.14)
|
||||
f, ok := st.GetFloat64("pi")
|
||||
require.True(t, ok)
|
||||
require.InDelta(t, 3.14, f, 0.0001)
|
||||
|
||||
// wrong type should return zero value
|
||||
st.Set("int", 10)
|
||||
f, ok = st.GetFloat64("int")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, 0.0, f, 0.0001)
|
||||
|
||||
// missing key should return zero value
|
||||
f, ok = st.GetFloat64("missing")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, 0.0, f, 0.0001)
|
||||
}
|
||||
|
||||
func TestState_GetUint(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint", uint(100))
|
||||
u, ok := st.GetUint("uint")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint(100), u)
|
||||
|
||||
st.Set("wrong", "not uint")
|
||||
u, ok = st.GetUint("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint(0), u)
|
||||
|
||||
u, ok = st.GetUint("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetInt8(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int8", int8(10))
|
||||
i, ok := st.GetInt8("int8")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int8(10), i)
|
||||
|
||||
st.Set("wrong", "not int8")
|
||||
i, ok = st.GetInt8("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int8(0), i)
|
||||
|
||||
i, ok = st.GetInt8("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int8(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetInt16(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int16", int16(200))
|
||||
i, ok := st.GetInt16("int16")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int16(200), i)
|
||||
|
||||
st.Set("wrong", "not int16")
|
||||
i, ok = st.GetInt16("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int16(0), i)
|
||||
|
||||
i, ok = st.GetInt16("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int16(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetInt32(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int32", int32(3000))
|
||||
i, ok := st.GetInt32("int32")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int32(3000), i)
|
||||
|
||||
st.Set("wrong", "not int32")
|
||||
i, ok = st.GetInt32("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int32(0), i)
|
||||
|
||||
i, ok = st.GetInt32("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int32(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetInt64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int64", int64(4000))
|
||||
i, ok := st.GetInt64("int64")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(4000), i)
|
||||
|
||||
st.Set("wrong", "not int64")
|
||||
i, ok = st.GetInt64("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int64(0), i)
|
||||
|
||||
i, ok = st.GetInt64("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int64(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetUint8(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint8", uint8(20))
|
||||
u, ok := st.GetUint8("uint8")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint8(20), u)
|
||||
|
||||
st.Set("wrong", "not uint8")
|
||||
u, ok = st.GetUint8("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint8(0), u)
|
||||
|
||||
u, ok = st.GetUint8("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint8(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUint16(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint16", uint16(300))
|
||||
u, ok := st.GetUint16("uint16")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint16(300), u)
|
||||
|
||||
st.Set("wrong", "not uint16")
|
||||
u, ok = st.GetUint16("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), u)
|
||||
|
||||
u, ok = st.GetUint16("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUint32(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint32", uint32(400000))
|
||||
u, ok := st.GetUint32("uint32")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint32(400000), u)
|
||||
|
||||
st.Set("wrong", "not uint32")
|
||||
u, ok = st.GetUint32("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint32(0), u)
|
||||
|
||||
u, ok = st.GetUint32("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint32(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUint64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint64", uint64(5000000))
|
||||
u, ok := st.GetUint64("uint64")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint64(5000000), u)
|
||||
|
||||
st.Set("wrong", "not uint64")
|
||||
u, ok = st.GetUint64("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint64(0), u)
|
||||
|
||||
u, ok = st.GetUint64("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint64(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUintptr(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
var ptr uintptr = 12345
|
||||
st.Set("uintptr", ptr)
|
||||
u, ok := st.GetUintptr("uintptr")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, ptr, u)
|
||||
|
||||
st.Set("wrong", "not uintptr")
|
||||
u, ok = st.GetUintptr("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uintptr(0), u)
|
||||
|
||||
u, ok = st.GetUintptr("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uintptr(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetFloat32(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("float32", float32(3.14))
|
||||
f, ok := st.GetFloat32("float32")
|
||||
require.True(t, ok)
|
||||
require.InDelta(t, float32(3.14), f, 0.0001)
|
||||
|
||||
st.Set("wrong", "not float32")
|
||||
f, ok = st.GetFloat32("wrong")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, float32(0), f, 0.0001)
|
||||
|
||||
f, ok = st.GetFloat32("missing")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, float32(0), f, 0.0001)
|
||||
}
|
||||
|
||||
func TestState_GetComplex64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
var c complex64 = complex(2, 3)
|
||||
st.Set("complex64", c)
|
||||
cRes, ok := st.GetComplex64("complex64")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, c, cRes)
|
||||
|
||||
st.Set("wrong", "not complex64")
|
||||
cRes, ok = st.GetComplex64("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex64(0), cRes)
|
||||
|
||||
cRes, ok = st.GetComplex64("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex64(0), cRes)
|
||||
}
|
||||
|
||||
func TestState_GetComplex128(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
c := complex(4, 5)
|
||||
st.Set("complex128", c)
|
||||
cRes, ok := st.GetComplex128("complex128")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, c, cRes)
|
||||
|
||||
st.Set("wrong", "not complex128")
|
||||
cRes, ok = st.GetComplex128("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex128(0), cRes)
|
||||
|
||||
cRes, ok = st.GetComplex128("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex128(0), cRes)
|
||||
}
|
||||
|
||||
func TestState_MustGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("exists", "value")
|
||||
val := st.MustGet("exists")
|
||||
require.Equal(t, "value", val)
|
||||
|
||||
// must-get on missing key should panic
|
||||
require.Panics(t, func() {
|
||||
_ = st.MustGet("missing")
|
||||
})
|
||||
}
|
||||
|
||||
func TestState_Has(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := newState()
|
||||
|
||||
st.Set("key", "value")
|
||||
require.True(t, st.Has("key"))
|
||||
}
|
||||
|
||||
func TestState_Delete(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("key", "value")
|
||||
st.Delete("key")
|
||||
_, ok := st.Get("key")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("a", 1)
|
||||
st.Set("b", 2)
|
||||
st.Reset()
|
||||
require.Equal(t, 0, st.Len())
|
||||
require.Empty(t, st.Keys())
|
||||
}
|
||||
|
||||
func TestState_Keys(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
keys := []string{"one", "two", "three"}
|
||||
for _, k := range keys {
|
||||
st.Set(k, k)
|
||||
}
|
||||
|
||||
returnedKeys := st.Keys()
|
||||
require.ElementsMatch(t, keys, returnedKeys)
|
||||
}
|
||||
|
||||
func TestState_Len(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
require.Equal(t, 0, st.Len())
|
||||
|
||||
st.Set("a", "a")
|
||||
require.Equal(t, 1, st.Len())
|
||||
|
||||
st.Set("b", "b")
|
||||
require.Equal(t, 2, st.Len())
|
||||
|
||||
st.Delete("a")
|
||||
require.Equal(t, 1, st.Len())
|
||||
}
|
||||
|
||||
type testCase[T any] struct { //nolint:govet // It does not really matter for test
|
||||
name string
|
||||
key string
|
||||
value any
|
||||
expected T
|
||||
ok bool
|
||||
}
|
||||
|
||||
func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) {
|
||||
t.Helper()
|
||||
|
||||
st := newState()
|
||||
for _, tc := range tests {
|
||||
st.Set(tc.key, tc.value)
|
||||
got, ok := getter(st, tc.key)
|
||||
require.Equal(t, tc.ok, ok, tc.name)
|
||||
require.Equal(t, tc.expected, got, tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_GetGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runGenericTest[int](t, GetState[int], []testCase[int]{
|
||||
{"int correct conversion", "num", 42, 42, true},
|
||||
{"int wrong conversion from string", "str", "abc", 0, false},
|
||||
})
|
||||
|
||||
runGenericTest[string](t, GetState[string], []testCase[string]{
|
||||
{"string correct conversion", "strVal", "hello", "hello", true},
|
||||
{"string wrong conversion from int", "intVal", 100, "", false},
|
||||
})
|
||||
|
||||
runGenericTest[bool](t, GetState[bool], []testCase[bool]{
|
||||
{"bool correct conversion", "flag", true, true, true},
|
||||
{"bool wrong conversion from int", "intFlag", 1, false, false},
|
||||
})
|
||||
|
||||
runGenericTest[float64](t, GetState[float64], []testCase[float64]{
|
||||
{"float64 correct conversion", "pi", 3.14, 3.14, true},
|
||||
{"float64 wrong conversion from int", "intVal", 10, 0.0, false},
|
||||
})
|
||||
}
|
||||
|
||||
func Test_MustGetStateGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
flag := MustGetState[bool](st, "flag")
|
||||
require.True(t, flag)
|
||||
|
||||
// mismatched type should panic
|
||||
require.Panics(t, func() {
|
||||
_ = MustGetState[string](st, "flag")
|
||||
})
|
||||
|
||||
// missing key should also panic
|
||||
require.Panics(t, func() {
|
||||
_ = MustGetState[string](st, "missing")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_GetStateWithDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
flag := GetStateWithDefault(st, "flag", false)
|
||||
require.True(t, flag)
|
||||
|
||||
// mismatched type should return the default value
|
||||
str := GetStateWithDefault(st, "flag", "default")
|
||||
require.Equal(t, "default", str)
|
||||
|
||||
// missing key should return the default value
|
||||
flag = GetStateWithDefault(st, "missing", false)
|
||||
require.False(t, flag)
|
||||
}
|
||||
|
||||
func BenchmarkState_Set(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Get(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.Get(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetString(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, strconv.Itoa(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetString(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetBool(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i%2 == 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetBool(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetFloat64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, float64(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetFloat64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_MustGet(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.MustGet(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetStateGeneric(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
GetState[int](st, key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_MustGetStateGeneric(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
MustGetState[int](st, key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetStateWithDefault(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
GetStateWithDefault[int](st, key, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Has(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st.Has("key" + strconv.Itoa(i%n))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Delete(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st := newState()
|
||||
st.Set("a", 1)
|
||||
st.Delete("a")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Reset(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st := newState()
|
||||
// add a fixed number of keys before clearing
|
||||
for j := 0; j < 100; j++ {
|
||||
st.Set("key"+strconv.Itoa(j), j)
|
||||
}
|
||||
st.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Keys(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = st.Keys()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Len(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = st.Len()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt8(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int8 values (using modulo to stay in range).
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int8(i%128)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt8(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt16(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int16 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int16(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt16(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt32(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int32 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int32(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt32(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int64 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int64(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint8(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint8 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint8(i%256)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint8(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint16(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint16 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint16(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint16(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint32(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint32 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint32(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint32(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint64 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint64(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUintptr(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uintptr values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uintptr(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUintptr(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetFloat32(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with float32 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, float32(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetFloat32(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetComplex64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with complex64 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
// Create a complex64 value with both real and imaginary parts.
|
||||
st.Set(key, complex(float32(i), float32(i)))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetComplex64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetComplex128(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with complex128 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
// Create a complex128 value with both real and imaginary parts.
|
||||
st.Set(key, complex(float64(i), float64(i)))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetComplex128(key)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue