mirror of https://github.com/gofiber/fiber.git
Merge branch 'feature/improve-check-constraint' of github.com:JIeJaitt/fiber into feature/improve-check-constraint
commit
6e399dc986
10
app.go
10
app.go
|
@ -106,6 +106,8 @@ type App struct {
|
|||
tlsHandler *TLSHandler
|
||||
// Mount fields
|
||||
mountFields *mountFields
|
||||
// state management
|
||||
state *State
|
||||
// Route stack divided by HTTP methods
|
||||
stack [][]*Route
|
||||
// Route stack divided by HTTP methods and route prefixes
|
||||
|
@ -527,6 +529,9 @@ func New(config ...Config) *App {
|
|||
// Define mountFields
|
||||
app.mountFields = newMountFields(app)
|
||||
|
||||
// Define state
|
||||
app.state = newState()
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
app.config = config[0]
|
||||
|
@ -964,6 +969,11 @@ func (app *App) Hooks() *Hooks {
|
|||
return app.hooks
|
||||
}
|
||||
|
||||
// State returns the state struct to store global data in order to share it between handlers.
|
||||
func (app *App) State() *State {
|
||||
return app.state
|
||||
}
|
||||
|
||||
var ErrTestGotEmptyResponse = errors.New("test: got empty response")
|
||||
|
||||
// TestConfig is a struct holding Test settings
|
||||
|
|
10
app_test.go
10
app_test.go
|
@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) {
|
|||
require.Equal(t, "/simple-route", sRoute2.Path)
|
||||
}
|
||||
|
||||
func Test_App_State(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := New()
|
||||
|
||||
app.State().Set("key", "value")
|
||||
str, ok := app.State().GetString("key")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "value", str)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4
|
||||
func Benchmark_Communication_Flow(b *testing.B) {
|
||||
app := New()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
id: constants
|
||||
title: 📋 Constants
|
||||
description: Some constants for Fiber.
|
||||
sidebar_position: 8
|
||||
sidebar_position: 9
|
||||
---
|
||||
|
||||
### HTTP methods were copied from net/http
|
||||
|
|
|
@ -0,0 +1,640 @@
|
|||
# State Management
|
||||
|
||||
This document details the state management functionality provided by Fiber, a thread-safe global key–value store used to store application dependencies and runtime data. The implementation is based on Go's `sync.Map`, ensuring concurrency safety.
|
||||
|
||||
Below is the detailed description of all public methods and usage examples.
|
||||
|
||||
## State Type
|
||||
|
||||
`State` is a key–value store built on top of `sync.Map`. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as thread–safe access to runtime data.
|
||||
|
||||
### Definition
|
||||
|
||||
```go
|
||||
// State is a key–value store for Fiber's app, used as a global storage for the app's dependencies.
|
||||
// It is a thread–safe implementation of a map[string]any, using sync.Map.
|
||||
type State struct {
|
||||
dependencies sync.Map
|
||||
}
|
||||
```
|
||||
|
||||
## Methods on State
|
||||
|
||||
### Set
|
||||
|
||||
Set adds or updates a key–value pair in the State.
|
||||
|
||||
```go
|
||||
// Set adds or updates a key–value pair in the State.
|
||||
func (s *State) Set(key string, value any)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
app.State().Set("appName", "My Fiber App")
|
||||
```
|
||||
|
||||
### Get
|
||||
|
||||
Get retrieves a value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Get(key string) (any, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
value, ok := app.State().Get("appName")
|
||||
if ok {
|
||||
fmt.Println("App Name:", value)
|
||||
}
|
||||
```
|
||||
|
||||
### MustGet
|
||||
|
||||
MustGet retrieves a value from the State and panics if the key is not found.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) MustGet(key string) any
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
appName := app.State().MustGet("appName")
|
||||
fmt.Println("App Name:", appName)
|
||||
```
|
||||
|
||||
### Has
|
||||
|
||||
Has checks if a key exists in the State.
|
||||
|
||||
```go title="Signature"s
|
||||
func (s *State) Has(key string) bool
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if app.State().Has("appName") {
|
||||
fmt.Println("App Name is set.")
|
||||
}
|
||||
```
|
||||
|
||||
### Delete
|
||||
|
||||
Delete removes a key–value pair from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Delete(key string)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
app.State().Delete("obsoleteKey")
|
||||
```
|
||||
|
||||
### Reset
|
||||
|
||||
Reset removes all keys from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Reset()
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
app.State().Reset()
|
||||
```
|
||||
|
||||
### Keys
|
||||
|
||||
Keys returns a slice containing all keys present in the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) Keys() []string
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
keys := app.State().Keys()
|
||||
fmt.Println("State Keys:", keys)
|
||||
```
|
||||
|
||||
### Len
|
||||
|
||||
Len returns the number of keys in the State.
|
||||
|
||||
```go
|
||||
// Len returns the number of keys in the State.
|
||||
func (s *State) Len() int
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
fmt.Printf("Total State Entries: %d\n", app.State().Len())
|
||||
```
|
||||
|
||||
### GetString
|
||||
|
||||
GetString retrieves a string value from the State. It returns the string and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetString(key string) (string, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if appName, ok := app.State().GetString("appName"); ok {
|
||||
fmt.Println("App Name:", appName)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt
|
||||
|
||||
GetInt retrieves an integer value from the State. It returns the int and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt(key string) (int, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if count, ok := app.State().GetInt("userCount"); ok {
|
||||
fmt.Printf("User Count: %d\n", count)
|
||||
}
|
||||
```
|
||||
|
||||
### GetBool
|
||||
|
||||
GetBool retrieves a boolean value from the State. It returns the bool and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetBool(key string) (value, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if debug, ok := app.State().GetBool("debugMode"); ok {
|
||||
fmt.Printf("Debug Mode: %v\n", debug)
|
||||
}
|
||||
```
|
||||
|
||||
### GetFloat64
|
||||
|
||||
GetFloat64 retrieves a float64 value from the State. It returns the float64 and a boolean indicating a successful type assertion.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetFloat64(key string) (float64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go title="Signature"
|
||||
if ratio, ok := app.State().GetFloat64("scalingFactor"); ok {
|
||||
fmt.Printf("Scaling Factor: %f\n", ratio)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint
|
||||
|
||||
GetUint retrieves a `uint` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint(key string) (uint, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint("maxConnections"); ok {
|
||||
fmt.Printf("Max Connections: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt8
|
||||
|
||||
GetInt8 retrieves an `int8` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt8(key string) (int8, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt8("threshold"); ok {
|
||||
fmt.Printf("Threshold: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt16
|
||||
|
||||
GetInt16 retrieves an `int16` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt16(key string) (int16, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt16("minValue"); ok {
|
||||
fmt.Printf("Minimum Value: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt32
|
||||
|
||||
GetInt32 retrieves an `int32` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt32(key string) (int32, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt32("portNumber"); ok {
|
||||
fmt.Printf("Port Number: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetInt64
|
||||
|
||||
GetInt64 retrieves an `int64` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetInt64(key string) (int64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetInt64("fileSize"); ok {
|
||||
fmt.Printf("File Size: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint8
|
||||
|
||||
GetUint8 retrieves a `uint8` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint8(key string) (uint8, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint8("byteValue"); ok {
|
||||
fmt.Printf("Byte Value: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint16
|
||||
|
||||
GetUint16 retrieves a `uint16` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint16(key string) (uint16, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint16("limit"); ok {
|
||||
fmt.Printf("Limit: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint32
|
||||
|
||||
GetUint32 retrieves a `uint32` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint32(key string) (uint32, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint32("timeout"); ok {
|
||||
fmt.Printf("Timeout: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUint64
|
||||
|
||||
GetUint64 retrieves a `uint64` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUint64(key string) (uint64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUint64("maxSize"); ok {
|
||||
fmt.Printf("Max Size: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetUintptr
|
||||
|
||||
GetUintptr retrieves a `uintptr` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetUintptr(key string) (uintptr, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetUintptr("pointerValue"); ok {
|
||||
fmt.Printf("Pointer Value: %d\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetFloat32
|
||||
|
||||
GetFloat32 retrieves a `float32` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetFloat32(key string) (float32, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetFloat32("scalingFactor32"); ok {
|
||||
fmt.Printf("Scaling Factor (float32): %f\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetComplex64
|
||||
|
||||
GetComplex64 retrieves a `complex64` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetComplex64(key string) (complex64, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetComplex64("complexVal"); ok {
|
||||
fmt.Printf("Complex Value (complex64): %v\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
### GetComplex128
|
||||
|
||||
GetComplex128 retrieves a `complex128` value from the State.
|
||||
|
||||
```go title="Signature"
|
||||
func (s *State) GetComplex128(key string) (complex128, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
if val, ok := app.State().GetComplex128("complexVal128"); ok {
|
||||
fmt.Printf("Complex Value (complex128): %v\n", val)
|
||||
}
|
||||
```
|
||||
|
||||
## Generic Functions
|
||||
|
||||
Fiber provides generic functions to retrieve state values with type safety and fallback options.
|
||||
|
||||
### GetState
|
||||
|
||||
GetState retrieves a value from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful.
|
||||
|
||||
```go title="Signature"
|
||||
func GetState[T any](s *State, key string) (T, bool)
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
// Retrieve an integer value safely.
|
||||
userCount, ok := GetState[int](app.State(), "userCount")
|
||||
if ok {
|
||||
fmt.Printf("User Count: %d\n", userCount)
|
||||
}
|
||||
```
|
||||
|
||||
### MustGetState
|
||||
|
||||
MustGetState retrieves a value from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails.
|
||||
|
||||
```go title="Signature"
|
||||
func MustGetState[T any](s *State, key string) T
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
// Retrieve the value or panic if it is not present.
|
||||
config := MustGetState[string](app.State(), "configFile")
|
||||
fmt.Println("Config File:", config)
|
||||
```
|
||||
|
||||
### GetStateWithDefault
|
||||
|
||||
GetStateWithDefault retrieves a value from the State, casting it to the desired type. If the key is not present, it returns the provided default value.
|
||||
|
||||
```go title="Signature"
|
||||
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T
|
||||
```
|
||||
|
||||
**Usage Example:**
|
||||
|
||||
```go
|
||||
// Retrieve a value with a default fallback.
|
||||
requestCount := GetStateWithDefault[int](app.State(), "requestCount", 0)
|
||||
fmt.Printf("Request Count: %d\n", requestCount)
|
||||
```
|
||||
|
||||
## Comprehensive Examples
|
||||
|
||||
### Example: Request Counter
|
||||
|
||||
This example demonstrates how to track the number of requests using the State.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
// Initialize state with a counter.
|
||||
app.State().Set("requestCount", 0)
|
||||
|
||||
// Middleware: Increase counter for every request.
|
||||
app.Use(func(c fiber.Ctx) error {
|
||||
count, _ := c.App().State().GetInt("requestCount")
|
||||
app.State().Set("requestCount", count+1)
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString("Hello World!")
|
||||
})
|
||||
|
||||
app.Get("/stats", func(c fiber.Ctx) error {
|
||||
count, _ := c.App().State().Get("requestCount")
|
||||
return c.SendString(fmt.Sprintf("Total requests: %d", count))
|
||||
})
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Environment–Specific Configuration
|
||||
|
||||
This example shows how to configure different settings based on the environment.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
// Determine environment.
|
||||
environment := os.Getenv("ENV")
|
||||
if environment == "" {
|
||||
environment = "development"
|
||||
}
|
||||
app.State().Set("environment", environment)
|
||||
|
||||
// Set environment-specific configurations.
|
||||
if environment == "development" {
|
||||
app.State().Set("apiUrl", "http://localhost:8080/api")
|
||||
app.State().Set("debug", true)
|
||||
} else {
|
||||
app.State().Set("apiUrl", "https://api.production.com")
|
||||
app.State().Set("debug", false)
|
||||
}
|
||||
|
||||
app.Get("/config", func(c fiber.Ctx) error {
|
||||
config := map[string]any{
|
||||
"environment": environment,
|
||||
"apiUrl": fiber.GetStateWithDefault(c.App().State(), "apiUrl", ""),
|
||||
"debug": fiber.GetStateWithDefault(c.App().State(), "debug", false),
|
||||
}
|
||||
return c.JSON(config)
|
||||
})
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Dependency Injection with State Management
|
||||
|
||||
This example demonstrates how to use the State for dependency injection in a Fiber application.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int `query:"id"`
|
||||
Name string `query:"name"`
|
||||
Email string `query:"email"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
ctx := context.Background()
|
||||
|
||||
// Initialize Redis client.
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "",
|
||||
DB: 0,
|
||||
})
|
||||
|
||||
// Check the Redis connection.
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
log.Fatalf("Could not connect to Redis: %v", err)
|
||||
}
|
||||
|
||||
// Inject the Redis client into Fiber's State for dependency injection.
|
||||
app.State().Set("redis", rdb)
|
||||
|
||||
app.Get("/user/create", func(c fiber.Ctx) error {
|
||||
var user User
|
||||
if err := c.Bind().Query(&user); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
|
||||
}
|
||||
|
||||
// Save the user to the database.
|
||||
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
|
||||
}
|
||||
|
||||
// Save the user to the database.
|
||||
key := fmt.Sprintf("user:%d", user.ID)
|
||||
err := rdb.HSet(ctx, key, "name", user.Name, "email", user.Email).Err()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
app.Get("/user/:id", func(c fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
|
||||
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("user:%s", id)
|
||||
user, err := rdb.HGetAll(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return c.Status(fiber.StatusNotFound).SendString("User not found")
|
||||
} else if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
app.Listen(":3000")
|
||||
}
|
||||
```
|
|
@ -27,7 +27,7 @@ Liveness, readiness and startup probes middleware for [Fiber](https://github.com
|
|||
## Signatures
|
||||
|
||||
```go
|
||||
func NewHealthChecker(config Config) fiber.Handler
|
||||
func New(config Config) fiber.Handler
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
@ -41,38 +41,44 @@ import(
|
|||
)
|
||||
```
|
||||
|
||||
After you initiate your [Fiber](https://github.com/gofiber/fiber) app, you can use the following possibilities:
|
||||
After you initiate your [Fiber](https://github.com/gofiber/fiber) app, you can use the following options:
|
||||
|
||||
```go
|
||||
// Provide a minimal config for liveness check
|
||||
app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker())
|
||||
app.Get(healthcheck.LivenessEndpoint, healthcheck.New())
|
||||
|
||||
// Provide a minimal config for readiness check
|
||||
app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker())
|
||||
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New())
|
||||
|
||||
// Provide a minimal config for startup check
|
||||
app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker())
|
||||
app.Get(healthcheck.StartupEndpoint, healthcheck.New())
|
||||
|
||||
// Provide a minimal config for check with custom endpoint
|
||||
app.Get("/live", healthcheck.NewHealthChecker())
|
||||
app.Get("/live", healthcheck.New())
|
||||
|
||||
// Or extend your config for customization
|
||||
app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
|
||||
app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{
|
||||
Probe: func(c fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
// And it works the same for readiness, just change the route
|
||||
app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
|
||||
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{
|
||||
Probe: func(c fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
// And it works the same for startup, just change the route
|
||||
app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
|
||||
app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{
|
||||
Probe: func(c fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
// With a custom route and custom probe
|
||||
app.Get("/live", healthcheck.NewHealthChecker(healthcheck.Config{
|
||||
app.Get("/live", healthcheck.New(healthcheck.Config{
|
||||
Probe: func(c fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
|
@ -81,7 +87,7 @@ app.Get("/live", healthcheck.NewHealthChecker(healthcheck.Config{
|
|||
// It can also be used with app.All, although it will only respond to requests with the GET method
|
||||
// in case of calling the route with any method which isn't GET, the return will be 404 Not Found when app.All is used
|
||||
// and 405 Method Not Allowed when app.Get is used
|
||||
app.All(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
|
||||
app.All(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{
|
||||
Probe: func(c fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
|
@ -108,7 +114,7 @@ type Config struct {
|
|||
// initialization and readiness checks
|
||||
//
|
||||
// Optional. Default: func(c fiber.Ctx) bool { return true }
|
||||
Probe HealthChecker
|
||||
Probe func(fiber.Ctx) bool
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -117,7 +123,7 @@ type Config struct {
|
|||
The default configuration used by this middleware is defined as follows:
|
||||
|
||||
```go
|
||||
func defaultProbe(fiber.Ctx) bool { return true }
|
||||
func defaultProbe(_ fiber.Ctx) bool { return true }
|
||||
|
||||
var ConfigDefault = Config{
|
||||
Probe: defaultProbe,
|
||||
|
|
|
@ -59,6 +59,7 @@ We have made several changes to the Fiber app, including:
|
|||
- **RegisterCustomBinder**: Allows for the registration of custom binders.
|
||||
- **RegisterCustomConstraint**: Allows for the registration of custom constraints.
|
||||
- **NewCtxFunc**: Introduces a new context function.
|
||||
- **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details.
|
||||
|
||||
### Removed Methods
|
||||
|
||||
|
@ -1564,25 +1565,25 @@ With the new version, each health check endpoint is configured separately, allow
|
|||
// after
|
||||
|
||||
// Default liveness endpoint configuration
|
||||
app.Get(healthcheck.DefaultLivenessEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
|
||||
app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{
|
||||
Probe: func(c fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
// Default readiness endpoint configuration
|
||||
app.Get(healthcheck.DefaultReadinessEndpoint, healthcheck.NewHealthChecker())
|
||||
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New())
|
||||
|
||||
// New default startup endpoint configuration
|
||||
// Default endpoint is /startupz
|
||||
app.Get(healthcheck.DefaultStartupEndpoint, healthcheck.NewHealthChecker(healthcheck.Config{
|
||||
app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{
|
||||
Probe: func(c fiber.Ctx) bool {
|
||||
return serviceA.Ready() && serviceB.Ready() && ...
|
||||
},
|
||||
}))
|
||||
|
||||
// Custom liveness endpoint configuration
|
||||
app.Get("/live", healthcheck.NewHealthChecker())
|
||||
app.Get("/live", healthcheck.New())
|
||||
```
|
||||
|
||||
#### Monitor
|
||||
|
|
2
go.mod
2
go.mod
|
@ -18,7 +18,7 @@ require (
|
|||
require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/fxamacker/cbor/v2 v2.7.0 // direct
|
||||
github.com/fxamacker/cbor/v2 v2.8.0 // direct
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
|
|
6
go.sum
6
go.sum
|
@ -2,8 +2,8 @@ github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7X
|
|||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
|
||||
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
|
||||
github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU=
|
||||
github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/gofiber/schema v1.3.0 h1:K3F3wYzAY+aivfCCEHPufCthu5/13r/lzp1nuk6mr3Q=
|
||||
github.com/gofiber/schema v1.3.0/go.mod h1:YYwj01w3hVfaNjhtJzaqetymL56VW642YS3qZPhuE6c=
|
||||
github.com/gofiber/utils/v2 v2.0.0-beta.7 h1:NnHFrRHvhrufPABdWajcKZejz9HnCWmT/asoxRsiEbQ=
|
||||
|
@ -34,8 +34,6 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
|
|||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
|
|
@ -18,18 +18,18 @@ type Config struct {
|
|||
// the application is in a state where it can handle requests (e.g., the server is up and running).
|
||||
//
|
||||
// Optional. Default: func(c fiber.Ctx) bool { return true }
|
||||
Probe HealthChecker
|
||||
Probe func(fiber.Ctx) bool
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultLivenessEndpoint = "/livez"
|
||||
DefaultReadinessEndpoint = "/readyz"
|
||||
DefaultStartupEndpoint = "/startupz"
|
||||
LivenessEndpoint = "/livez"
|
||||
ReadinessEndpoint = "/readyz"
|
||||
StartupEndpoint = "/startupz"
|
||||
)
|
||||
|
||||
func defaultProbe(fiber.Ctx) bool { return true }
|
||||
func defaultProbe(_ fiber.Ctx) bool { return true }
|
||||
|
||||
func defaultConfigV3(config ...Config) Config {
|
||||
func defaultConfig(config ...Config) Config {
|
||||
if len(config) < 1 {
|
||||
return Config{
|
||||
Probe: defaultProbe,
|
||||
|
|
|
@ -4,11 +4,8 @@ import (
|
|||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
// HealthChecker defines a function to check liveness or readiness of the application
|
||||
type HealthChecker func(fiber.Ctx) bool
|
||||
|
||||
func NewHealthChecker(config ...Config) fiber.Handler {
|
||||
cfg := defaultConfigV3(config...)
|
||||
func New(config ...Config) fiber.Handler {
|
||||
cfg := defaultConfig(config...)
|
||||
|
||||
return func(c fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
|
|
|
@ -34,9 +34,9 @@ func Test_HealthCheck_Strict_Routing_Default(t *testing.T) {
|
|||
StrictRouting: true,
|
||||
})
|
||||
|
||||
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultStartupEndpoint, NewHealthChecker())
|
||||
app.Get(LivenessEndpoint, New())
|
||||
app.Get(ReadinessEndpoint, New())
|
||||
app.Get(StartupEndpoint, New())
|
||||
|
||||
shouldGiveOK(t, app, "/readyz")
|
||||
shouldGiveOK(t, app, "/livez")
|
||||
|
@ -53,9 +53,9 @@ func Test_HealthCheck_Default(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultStartupEndpoint, NewHealthChecker())
|
||||
app.Get(LivenessEndpoint, New())
|
||||
app.Get(ReadinessEndpoint, New())
|
||||
app.Get(StartupEndpoint, New())
|
||||
|
||||
shouldGiveOK(t, app, "/readyz")
|
||||
shouldGiveOK(t, app, "/livez")
|
||||
|
@ -73,12 +73,12 @@ func Test_HealthCheck_Custom(t *testing.T) {
|
|||
|
||||
app := fiber.New()
|
||||
c1 := make(chan struct{}, 1)
|
||||
app.Get("/live", NewHealthChecker(Config{
|
||||
app.Get("/live", New(Config{
|
||||
Probe: func(_ fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
app.Get("/ready", NewHealthChecker(Config{
|
||||
app.Get("/ready", New(Config{
|
||||
Probe: func(_ fiber.Ctx) bool {
|
||||
select {
|
||||
case <-c1:
|
||||
|
@ -88,7 +88,7 @@ func Test_HealthCheck_Custom(t *testing.T) {
|
|||
}
|
||||
},
|
||||
}))
|
||||
app.Get(DefaultStartupEndpoint, NewHealthChecker(Config{
|
||||
app.Get(StartupEndpoint, New(Config{
|
||||
Probe: func(_ fiber.Ctx) bool {
|
||||
return false
|
||||
},
|
||||
|
@ -123,12 +123,12 @@ func Test_HealthCheck_Custom_Nested(t *testing.T) {
|
|||
app := fiber.New()
|
||||
|
||||
c1 := make(chan struct{}, 1)
|
||||
app.Get("/probe/live", NewHealthChecker(Config{
|
||||
app.Get("/probe/live", New(Config{
|
||||
Probe: func(_ fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
app.Get("/probe/ready", NewHealthChecker(Config{
|
||||
app.Get("/probe/ready", New(Config{
|
||||
Probe: func(_ fiber.Ctx) bool {
|
||||
select {
|
||||
case <-c1:
|
||||
|
@ -164,15 +164,15 @@ func Test_HealthCheck_Next(t *testing.T) {
|
|||
|
||||
app := fiber.New()
|
||||
|
||||
checker := NewHealthChecker(Config{
|
||||
checker := New(Config{
|
||||
Next: func(_ fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
})
|
||||
|
||||
app.Get(DefaultLivenessEndpoint, checker)
|
||||
app.Get(DefaultReadinessEndpoint, checker)
|
||||
app.Get(DefaultStartupEndpoint, checker)
|
||||
app.Get(LivenessEndpoint, checker)
|
||||
app.Get(ReadinessEndpoint, checker)
|
||||
app.Get(StartupEndpoint, checker)
|
||||
|
||||
// This should give not found since there are no other handlers to execute
|
||||
// so it's like the route isn't defined at all
|
||||
|
@ -184,9 +184,9 @@ func Test_HealthCheck_Next(t *testing.T) {
|
|||
func Benchmark_HealthCheck(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultStartupEndpoint, NewHealthChecker())
|
||||
app.Get(LivenessEndpoint, New())
|
||||
app.Get(ReadinessEndpoint, New())
|
||||
app.Get(StartupEndpoint, New())
|
||||
|
||||
h := app.Handler()
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
|
@ -206,9 +206,9 @@ func Benchmark_HealthCheck(b *testing.B) {
|
|||
func Benchmark_HealthCheck_Parallel(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Get(DefaultLivenessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultReadinessEndpoint, NewHealthChecker())
|
||||
app.Get(DefaultStartupEndpoint, NewHealthChecker())
|
||||
app.Get(LivenessEndpoint, New())
|
||||
app.Get(ReadinessEndpoint, New())
|
||||
app.Get(StartupEndpoint, New())
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
package keyauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
@ -60,10 +59,7 @@ func New(config ...Config) fiber.Handler {
|
|||
valid, err := cfg.Validator(c, key)
|
||||
|
||||
if err == nil && valid {
|
||||
// Store in both Locals and Context
|
||||
c.Locals(tokenKey, key)
|
||||
ctx := context.WithValue(c.Context(), tokenKey, key)
|
||||
c.SetContext(ctx)
|
||||
return cfg.SuccessHandler(c)
|
||||
}
|
||||
return cfg.ErrorHandler(c, err)
|
||||
|
@ -72,20 +68,12 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// TokenFromContext returns the bearer token from the request context.
|
||||
// returns an empty string if the token does not exist
|
||||
func TokenFromContext(c any) string {
|
||||
switch ctx := c.(type) {
|
||||
case context.Context:
|
||||
if token, ok := ctx.Value(tokenKey).(string); ok {
|
||||
return token
|
||||
}
|
||||
case fiber.Ctx:
|
||||
if token, ok := ctx.Locals(tokenKey).(string); ok {
|
||||
return token
|
||||
}
|
||||
default:
|
||||
panic("unsupported context type, expected fiber.Ctx or context.Context")
|
||||
func TokenFromContext(c fiber.Ctx) string {
|
||||
token, ok := c.Locals(tokenKey).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
return token
|
||||
}
|
||||
|
||||
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
|
||||
|
|
|
@ -503,67 +503,33 @@ func Test_TokenFromContext_None(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_TokenFromContext(t *testing.T) {
|
||||
// Test that TokenFromContext returns the correct token
|
||||
t.Run("fiber.Ctx", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString(TokenFromContext(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, CorrectKey, string(body))
|
||||
app := fiber.New()
|
||||
// Wire up keyauth middleware to set TokenFromContext now
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
// Define a test handler that checks TokenFromContext
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString(TokenFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("context.Context", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
// Verify that TokenFromContext works with context.Context
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
token := TokenFromContext(ctx)
|
||||
return c.SendString(token)
|
||||
})
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||
// Send
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, CorrectKey, string(body))
|
||||
})
|
||||
|
||||
t.Run("invalid context type", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_ = TokenFromContext("invalid")
|
||||
})
|
||||
})
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, CorrectKey, string(body))
|
||||
}
|
||||
|
||||
func Test_AuthSchemeToken(t *testing.T) {
|
||||
|
|
46
path.go
46
path.go
|
@ -123,16 +123,6 @@ var (
|
|||
parameterDelimiterChars = append([]byte{paramStarterChar, escapeChar}, routeDelimiter...)
|
||||
// list of chars to find the end of a parameter
|
||||
parameterEndChars = append([]byte{optionalParam}, parameterDelimiterChars...)
|
||||
// list of parameter constraint start
|
||||
parameterConstraintStartChars = []byte{paramConstraintStart}
|
||||
// list of parameter constraint end
|
||||
parameterConstraintEndChars = []byte{paramConstraintEnd}
|
||||
// list of parameter separator
|
||||
parameterConstraintSeparatorChars = []byte{paramConstraintSeparator}
|
||||
// list of parameter constraint data start
|
||||
parameterConstraintDataStartChars = []byte{paramConstraintDataStart}
|
||||
// list of parameter constraint data separator
|
||||
parameterConstraintDataSeparatorChars = []byte{paramConstraintDataSeparator}
|
||||
)
|
||||
|
||||
// RoutePatternMatch checks if a given path matches a Fiber route pattern.
|
||||
|
@ -337,8 +327,8 @@ func (parser *routeParser) analyseParameterPart(pattern string, customConstraint
|
|||
|
||||
// find constraint part if exists in the parameter part and remove it
|
||||
if parameterEndPosition > 0 {
|
||||
parameterConstraintStart = findNextNonEscapedCharsetPosition(pattern[0:parameterEndPosition], parameterConstraintStartChars)
|
||||
parameterConstraintEnd = strings.LastIndexByte(pattern[0:parameterEndPosition+1], paramConstraintEnd)
|
||||
parameterConstraintStart = findNextNonEscapedCharPosition(pattern[:parameterEndPosition], paramConstraintStart)
|
||||
parameterConstraintEnd = strings.LastIndexByte(pattern[:parameterEndPosition+1], paramConstraintEnd)
|
||||
}
|
||||
|
||||
// cut params part
|
||||
|
@ -351,11 +341,11 @@ func (parser *routeParser) analyseParameterPart(pattern string, customConstraint
|
|||
|
||||
if hasConstraint := parameterConstraintStart != -1 && parameterConstraintEnd != -1; hasConstraint {
|
||||
constraintString := pattern[parameterConstraintStart+1 : parameterConstraintEnd]
|
||||
userConstraints := splitNonEscaped(constraintString, string(parameterConstraintSeparatorChars))
|
||||
userConstraints := splitNonEscaped(constraintString, paramConstraintSeparator)
|
||||
constraints = make([]*Constraint, 0, len(userConstraints))
|
||||
|
||||
for _, c := range userConstraints {
|
||||
start := findNextNonEscapedCharsetPosition(c, parameterConstraintDataStartChars)
|
||||
start := findNextNonEscapedCharPosition(c, paramConstraintDataStart)
|
||||
end := strings.LastIndexByte(c, paramConstraintDataEnd)
|
||||
|
||||
// Assign constraint
|
||||
|
@ -368,7 +358,7 @@ func (parser *routeParser) analyseParameterPart(pattern string, customConstraint
|
|||
|
||||
// remove escapes from data
|
||||
if constraint.ID != regexConstraint {
|
||||
constraint.Data = splitNonEscaped(c[start+1:end], string(parameterConstraintDataSeparatorChars))
|
||||
constraint.Data = splitNonEscaped(c[start+1:end], paramConstraintDataSeparator)
|
||||
if len(constraint.Data) == 1 {
|
||||
constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
|
||||
} else if len(constraint.Data) == 2 { // This is fine, we simply expect two parts
|
||||
|
@ -432,11 +422,11 @@ func findNextCharsetPosition(search string, charset []byte) int {
|
|||
return nextPosition
|
||||
}
|
||||
|
||||
// findNextCharsetPositionConstraint search the next char position from the charset
|
||||
// findNextCharsetPositionConstraint searches the next char position from the charset
|
||||
// unlike findNextCharsetPosition, it takes care of constraint start-end chars to parse route pattern
|
||||
func findNextCharsetPositionConstraint(search string, charset []byte) int {
|
||||
constraintStart := findNextNonEscapedCharsetPosition(search, parameterConstraintStartChars)
|
||||
constraintEnd := findNextNonEscapedCharsetPosition(search, parameterConstraintEndChars)
|
||||
constraintStart := findNextNonEscapedCharPosition(search, paramConstraintStart)
|
||||
constraintEnd := findNextNonEscapedCharPosition(search, paramConstraintEnd)
|
||||
nextPosition := -1
|
||||
|
||||
for _, char := range charset {
|
||||
|
@ -452,7 +442,7 @@ func findNextCharsetPositionConstraint(search string, charset []byte) int {
|
|||
return nextPosition
|
||||
}
|
||||
|
||||
// findNextNonEscapedCharsetPosition search the next char position from the charset and skip the escaped characters
|
||||
// findNextNonEscapedCharsetPosition searches the next char position from the charset and skips the escaped characters
|
||||
func findNextNonEscapedCharsetPosition(search string, charset []byte) int {
|
||||
pos := findNextCharsetPosition(search, charset)
|
||||
for pos > 0 && search[pos-1] == escapeChar {
|
||||
|
@ -471,16 +461,26 @@ func findNextNonEscapedCharsetPosition(search string, charset []byte) int {
|
|||
return pos
|
||||
}
|
||||
|
||||
// findNextNonEscapedCharPosition searches the next char position and skips the escaped characters
|
||||
func findNextNonEscapedCharPosition(search string, char byte) int {
|
||||
for i := 0; i < len(search); i++ {
|
||||
if search[i] == char && (i == 0 || search[i-1] != escapeChar) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// splitNonEscaped slices s into all substrings separated by sep and returns a slice of the substrings between those separators
|
||||
// This function also takes a care of escape char when splitting.
|
||||
func splitNonEscaped(s, sep string) []string {
|
||||
func splitNonEscaped(s string, sep byte) []string {
|
||||
var result []string
|
||||
i := findNextNonEscapedCharsetPosition(s, []byte(sep))
|
||||
i := findNextNonEscapedCharPosition(s, sep)
|
||||
|
||||
for i > -1 {
|
||||
result = append(result, s[:i])
|
||||
s = s[i+len(sep):]
|
||||
i = findNextNonEscapedCharsetPosition(s, []byte(sep))
|
||||
s = s[i+1:]
|
||||
i = findNextNonEscapedCharPosition(s, sep)
|
||||
}
|
||||
|
||||
return append(result, s)
|
||||
|
|
|
@ -0,0 +1,322 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies.
|
||||
// It's a thread-safe implementation of a map[string]any, using sync.Map.
|
||||
type State struct {
|
||||
dependencies sync.Map
|
||||
}
|
||||
|
||||
// NewState creates a new instance of State.
|
||||
func newState() *State {
|
||||
return &State{
|
||||
dependencies: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a key-value pair in the State.
|
||||
func (s *State) Set(key string, value any) {
|
||||
s.dependencies.Store(key, value)
|
||||
}
|
||||
|
||||
// Get retrieves a value from the State.
|
||||
func (s *State) Get(key string) (any, bool) {
|
||||
return s.dependencies.Load(key)
|
||||
}
|
||||
|
||||
// MustGet retrieves a value from the State and panics if the key is not found.
|
||||
func (s *State) MustGet(key string) any {
|
||||
if dep, ok := s.Get(key); ok {
|
||||
return dep
|
||||
}
|
||||
|
||||
panic("state: dependency not found!")
|
||||
}
|
||||
|
||||
// Has checks if a key is present in the State.
|
||||
// It returns a boolean indicating if the key is present.
|
||||
func (s *State) Has(key string) bool {
|
||||
_, ok := s.Get(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Delete removes a key-value pair from the State.
|
||||
func (s *State) Delete(key string) {
|
||||
s.dependencies.Delete(key)
|
||||
}
|
||||
|
||||
// Reset resets the State by removing all keys.
|
||||
func (s *State) Reset() {
|
||||
s.dependencies.Clear()
|
||||
}
|
||||
|
||||
// Keys returns a slice containing all keys present in the State.
|
||||
func (s *State) Keys() []string {
|
||||
keys := make([]string, 0)
|
||||
s.dependencies.Range(func(key, _ any) bool {
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
keys = append(keys, keyStr)
|
||||
return true
|
||||
})
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// Len returns the number of keys in the State.
|
||||
func (s *State) Len() int {
|
||||
length := 0
|
||||
s.dependencies.Range(func(_, _ any) bool {
|
||||
length++
|
||||
return true
|
||||
})
|
||||
|
||||
return length
|
||||
}
|
||||
|
||||
// GetState retrieves a value from the State and casts it to the desired type.
|
||||
// It returns the casted value and a boolean indicating if the cast was successful.
|
||||
func GetState[T any](s *State, key string) (T, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
|
||||
if ok {
|
||||
depT, okCast := dep.(T)
|
||||
return depT, okCast
|
||||
}
|
||||
|
||||
var zeroVal T
|
||||
return zeroVal, false
|
||||
}
|
||||
|
||||
// MustGetState retrieves a value from the State and casts it to the desired type.
|
||||
// It panics if the key is not found or if the type assertion fails.
|
||||
func MustGetState[T any](s *State, key string) T {
|
||||
dep, ok := GetState[T](s, key)
|
||||
if !ok {
|
||||
panic("state: dependency not found!")
|
||||
}
|
||||
|
||||
return dep
|
||||
}
|
||||
|
||||
// GetStateWithDefault retrieves a value from the State,
|
||||
// casting it to the desired type. If the key is not present,
|
||||
// it returns the provided default value.
|
||||
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T {
|
||||
dep, ok := GetState[T](s, key)
|
||||
if !ok {
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
return dep
|
||||
}
|
||||
|
||||
// GetString retrieves a string value from the State.
|
||||
// It returns the string and a boolean indicating successful type assertion.
|
||||
func (s *State) GetString(key string) (string, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depString, okCast := dep.(string)
|
||||
return depString, okCast
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetInt retrieves an integer value from the State.
|
||||
// It returns the int and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt(key string) (int, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depInt, okCast := dep.(int)
|
||||
return depInt, okCast
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetBool retrieves a boolean value from the State.
|
||||
// It returns the bool and a boolean indicating successful type assertion.
|
||||
func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depBool, okCast := dep.(bool)
|
||||
return depBool, okCast
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetFloat64 retrieves a float64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetFloat64(key string) (float64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depFloat64, okCast := dep.(float64)
|
||||
return depFloat64, okCast
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint retrieves a uint value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint(key string) (uint, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint, okCast := dep.(uint); okCast {
|
||||
return depUint, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt8 retrieves an int8 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt8(key string) (int8, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt8, okCast := dep.(int8); okCast {
|
||||
return depInt8, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt16 retrieves an int16 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt16(key string) (int16, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt16, okCast := dep.(int16); okCast {
|
||||
return depInt16, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt32 retrieves an int32 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt32(key string) (int32, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt32, okCast := dep.(int32); okCast {
|
||||
return depInt32, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetInt64 retrieves an int64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetInt64(key string) (int64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depInt64, okCast := dep.(int64); okCast {
|
||||
return depInt64, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint8 retrieves a uint8 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint8(key string) (uint8, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint8, okCast := dep.(uint8); okCast {
|
||||
return depUint8, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint16 retrieves a uint16 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint16(key string) (uint16, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint16, okCast := dep.(uint16); okCast {
|
||||
return depUint16, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint32 retrieves a uint32 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint32(key string) (uint32, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint32, okCast := dep.(uint32); okCast {
|
||||
return depUint32, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUint64 retrieves a uint64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUint64(key string) (uint64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUint64, okCast := dep.(uint64); okCast {
|
||||
return depUint64, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetUintptr retrieves a uintptr value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetUintptr(key string) (uintptr, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depUintptr, okCast := dep.(uintptr); okCast {
|
||||
return depUintptr, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetFloat32 retrieves a float32 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetFloat32(key string) (float32, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depFloat32, okCast := dep.(float32); okCast {
|
||||
return depFloat32, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetComplex64 retrieves a complex64 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetComplex64(key string) (complex64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depComplex64, okCast := dep.(complex64); okCast {
|
||||
return depComplex64, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetComplex128 retrieves a complex128 value from the State.
|
||||
// It returns the float64 and a boolean indicating successful type assertion.
|
||||
func (s *State) GetComplex128(key string) (complex128, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
if depComplex128, okCast := dep.(complex128); okCast {
|
||||
return depComplex128, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
|
@ -0,0 +1,981 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestState_SetAndGet_WithApp(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create app
|
||||
app := New()
|
||||
|
||||
// test setting and getting a value
|
||||
app.State().Set("foo", "bar")
|
||||
val, ok := app.State().Get("foo")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "bar", val)
|
||||
|
||||
// test key not found
|
||||
_, ok = app.State().Get("unknown")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_SetAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
// test setting and getting a value
|
||||
st.Set("foo", "bar")
|
||||
val, ok := st.Get("foo")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "bar", val)
|
||||
|
||||
// test key not found
|
||||
_, ok = st.Get("unknown")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_GetString(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("str", "hello")
|
||||
s, ok := st.GetString("str")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "hello", s)
|
||||
|
||||
// wrong type should return false
|
||||
st.Set("num", 123)
|
||||
s, ok = st.GetString("num")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, "", s)
|
||||
|
||||
// missing key should return false
|
||||
s, ok = st.GetString("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, "", s)
|
||||
}
|
||||
|
||||
func TestState_GetInt(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("num", 456)
|
||||
i, ok := st.GetInt("num")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 456, i)
|
||||
|
||||
// wrong type should return zero value
|
||||
st.Set("str", "abc")
|
||||
i, ok = st.GetInt("str")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0, i)
|
||||
|
||||
// missing key should return zero value
|
||||
i, ok = st.GetInt("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0, i)
|
||||
}
|
||||
|
||||
func TestState_GetBool(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
b, ok := st.GetBool("flag")
|
||||
require.True(t, ok)
|
||||
require.True(t, b)
|
||||
|
||||
// wrong type
|
||||
st.Set("num", 1)
|
||||
b, ok = st.GetBool("num")
|
||||
require.False(t, ok)
|
||||
require.False(t, b)
|
||||
|
||||
// missing key should return false
|
||||
b, ok = st.GetBool("missing")
|
||||
require.False(t, ok)
|
||||
require.False(t, b)
|
||||
}
|
||||
|
||||
func TestState_GetFloat64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("pi", 3.14)
|
||||
f, ok := st.GetFloat64("pi")
|
||||
require.True(t, ok)
|
||||
require.InDelta(t, 3.14, f, 0.0001)
|
||||
|
||||
// wrong type should return zero value
|
||||
st.Set("int", 10)
|
||||
f, ok = st.GetFloat64("int")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, 0.0, f, 0.0001)
|
||||
|
||||
// missing key should return zero value
|
||||
f, ok = st.GetFloat64("missing")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, 0.0, f, 0.0001)
|
||||
}
|
||||
|
||||
func TestState_GetUint(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint", uint(100))
|
||||
u, ok := st.GetUint("uint")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint(100), u)
|
||||
|
||||
st.Set("wrong", "not uint")
|
||||
u, ok = st.GetUint("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint(0), u)
|
||||
|
||||
u, ok = st.GetUint("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetInt8(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int8", int8(10))
|
||||
i, ok := st.GetInt8("int8")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int8(10), i)
|
||||
|
||||
st.Set("wrong", "not int8")
|
||||
i, ok = st.GetInt8("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int8(0), i)
|
||||
|
||||
i, ok = st.GetInt8("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int8(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetInt16(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int16", int16(200))
|
||||
i, ok := st.GetInt16("int16")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int16(200), i)
|
||||
|
||||
st.Set("wrong", "not int16")
|
||||
i, ok = st.GetInt16("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int16(0), i)
|
||||
|
||||
i, ok = st.GetInt16("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int16(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetInt32(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int32", int32(3000))
|
||||
i, ok := st.GetInt32("int32")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int32(3000), i)
|
||||
|
||||
st.Set("wrong", "not int32")
|
||||
i, ok = st.GetInt32("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int32(0), i)
|
||||
|
||||
i, ok = st.GetInt32("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int32(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetInt64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("int64", int64(4000))
|
||||
i, ok := st.GetInt64("int64")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(4000), i)
|
||||
|
||||
st.Set("wrong", "not int64")
|
||||
i, ok = st.GetInt64("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int64(0), i)
|
||||
|
||||
i, ok = st.GetInt64("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int64(0), i)
|
||||
}
|
||||
|
||||
func TestState_GetUint8(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint8", uint8(20))
|
||||
u, ok := st.GetUint8("uint8")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint8(20), u)
|
||||
|
||||
st.Set("wrong", "not uint8")
|
||||
u, ok = st.GetUint8("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint8(0), u)
|
||||
|
||||
u, ok = st.GetUint8("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint8(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUint16(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint16", uint16(300))
|
||||
u, ok := st.GetUint16("uint16")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint16(300), u)
|
||||
|
||||
st.Set("wrong", "not uint16")
|
||||
u, ok = st.GetUint16("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), u)
|
||||
|
||||
u, ok = st.GetUint16("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUint32(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint32", uint32(400000))
|
||||
u, ok := st.GetUint32("uint32")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint32(400000), u)
|
||||
|
||||
st.Set("wrong", "not uint32")
|
||||
u, ok = st.GetUint32("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint32(0), u)
|
||||
|
||||
u, ok = st.GetUint32("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint32(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUint64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("uint64", uint64(5000000))
|
||||
u, ok := st.GetUint64("uint64")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint64(5000000), u)
|
||||
|
||||
st.Set("wrong", "not uint64")
|
||||
u, ok = st.GetUint64("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint64(0), u)
|
||||
|
||||
u, ok = st.GetUint64("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint64(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetUintptr(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
var ptr uintptr = 12345
|
||||
st.Set("uintptr", ptr)
|
||||
u, ok := st.GetUintptr("uintptr")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, ptr, u)
|
||||
|
||||
st.Set("wrong", "not uintptr")
|
||||
u, ok = st.GetUintptr("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uintptr(0), u)
|
||||
|
||||
u, ok = st.GetUintptr("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uintptr(0), u)
|
||||
}
|
||||
|
||||
func TestState_GetFloat32(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("float32", float32(3.14))
|
||||
f, ok := st.GetFloat32("float32")
|
||||
require.True(t, ok)
|
||||
require.InDelta(t, float32(3.14), f, 0.0001)
|
||||
|
||||
st.Set("wrong", "not float32")
|
||||
f, ok = st.GetFloat32("wrong")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, float32(0), f, 0.0001)
|
||||
|
||||
f, ok = st.GetFloat32("missing")
|
||||
require.False(t, ok)
|
||||
require.InDelta(t, float32(0), f, 0.0001)
|
||||
}
|
||||
|
||||
func TestState_GetComplex64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
var c complex64 = complex(2, 3)
|
||||
st.Set("complex64", c)
|
||||
cRes, ok := st.GetComplex64("complex64")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, c, cRes)
|
||||
|
||||
st.Set("wrong", "not complex64")
|
||||
cRes, ok = st.GetComplex64("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex64(0), cRes)
|
||||
|
||||
cRes, ok = st.GetComplex64("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex64(0), cRes)
|
||||
}
|
||||
|
||||
func TestState_GetComplex128(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
c := complex(4, 5)
|
||||
st.Set("complex128", c)
|
||||
cRes, ok := st.GetComplex128("complex128")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, c, cRes)
|
||||
|
||||
st.Set("wrong", "not complex128")
|
||||
cRes, ok = st.GetComplex128("wrong")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex128(0), cRes)
|
||||
|
||||
cRes, ok = st.GetComplex128("missing")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, complex128(0), cRes)
|
||||
}
|
||||
|
||||
func TestState_MustGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("exists", "value")
|
||||
val := st.MustGet("exists")
|
||||
require.Equal(t, "value", val)
|
||||
|
||||
// must-get on missing key should panic
|
||||
require.Panics(t, func() {
|
||||
_ = st.MustGet("missing")
|
||||
})
|
||||
}
|
||||
|
||||
func TestState_Has(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := newState()
|
||||
|
||||
st.Set("key", "value")
|
||||
require.True(t, st.Has("key"))
|
||||
}
|
||||
|
||||
func TestState_Delete(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("key", "value")
|
||||
st.Delete("key")
|
||||
_, ok := st.Get("key")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("a", 1)
|
||||
st.Set("b", 2)
|
||||
st.Reset()
|
||||
require.Equal(t, 0, st.Len())
|
||||
require.Empty(t, st.Keys())
|
||||
}
|
||||
|
||||
func TestState_Keys(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
keys := []string{"one", "two", "three"}
|
||||
for _, k := range keys {
|
||||
st.Set(k, k)
|
||||
}
|
||||
|
||||
returnedKeys := st.Keys()
|
||||
require.ElementsMatch(t, keys, returnedKeys)
|
||||
}
|
||||
|
||||
func TestState_Len(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
require.Equal(t, 0, st.Len())
|
||||
|
||||
st.Set("a", "a")
|
||||
require.Equal(t, 1, st.Len())
|
||||
|
||||
st.Set("b", "b")
|
||||
require.Equal(t, 2, st.Len())
|
||||
|
||||
st.Delete("a")
|
||||
require.Equal(t, 1, st.Len())
|
||||
}
|
||||
|
||||
type testCase[T any] struct { //nolint:govet // It does not really matter for test
|
||||
name string
|
||||
key string
|
||||
value any
|
||||
expected T
|
||||
ok bool
|
||||
}
|
||||
|
||||
func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) {
|
||||
t.Helper()
|
||||
|
||||
st := newState()
|
||||
for _, tc := range tests {
|
||||
st.Set(tc.key, tc.value)
|
||||
got, ok := getter(st, tc.key)
|
||||
require.Equal(t, tc.ok, ok, tc.name)
|
||||
require.Equal(t, tc.expected, got, tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_GetGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runGenericTest[int](t, GetState[int], []testCase[int]{
|
||||
{"int correct conversion", "num", 42, 42, true},
|
||||
{"int wrong conversion from string", "str", "abc", 0, false},
|
||||
})
|
||||
|
||||
runGenericTest[string](t, GetState[string], []testCase[string]{
|
||||
{"string correct conversion", "strVal", "hello", "hello", true},
|
||||
{"string wrong conversion from int", "intVal", 100, "", false},
|
||||
})
|
||||
|
||||
runGenericTest[bool](t, GetState[bool], []testCase[bool]{
|
||||
{"bool correct conversion", "flag", true, true, true},
|
||||
{"bool wrong conversion from int", "intFlag", 1, false, false},
|
||||
})
|
||||
|
||||
runGenericTest[float64](t, GetState[float64], []testCase[float64]{
|
||||
{"float64 correct conversion", "pi", 3.14, 3.14, true},
|
||||
{"float64 wrong conversion from int", "intVal", 10, 0.0, false},
|
||||
})
|
||||
}
|
||||
|
||||
func Test_MustGetStateGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
flag := MustGetState[bool](st, "flag")
|
||||
require.True(t, flag)
|
||||
|
||||
// mismatched type should panic
|
||||
require.Panics(t, func() {
|
||||
_ = MustGetState[string](st, "flag")
|
||||
})
|
||||
|
||||
// missing key should also panic
|
||||
require.Panics(t, func() {
|
||||
_ = MustGetState[string](st, "missing")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_GetStateWithDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
flag := GetStateWithDefault(st, "flag", false)
|
||||
require.True(t, flag)
|
||||
|
||||
// mismatched type should return the default value
|
||||
str := GetStateWithDefault(st, "flag", "default")
|
||||
require.Equal(t, "default", str)
|
||||
|
||||
// missing key should return the default value
|
||||
flag = GetStateWithDefault(st, "missing", false)
|
||||
require.False(t, flag)
|
||||
}
|
||||
|
||||
func BenchmarkState_Set(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Get(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.Get(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetString(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, strconv.Itoa(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetString(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetBool(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i%2 == 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetBool(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetFloat64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, float64(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetFloat64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_MustGet(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.MustGet(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetStateGeneric(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
GetState[int](st, key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_MustGetStateGeneric(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
MustGetState[int](st, key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetStateWithDefault(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
GetStateWithDefault[int](st, key, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Has(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st.Has("key" + strconv.Itoa(i%n))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Delete(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st := newState()
|
||||
st.Set("a", 1)
|
||||
st.Delete("a")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Reset(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st := newState()
|
||||
// add a fixed number of keys before clearing
|
||||
for j := 0; j < 100; j++ {
|
||||
st.Set("key"+strconv.Itoa(j), j)
|
||||
}
|
||||
st.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Keys(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = st.Keys()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Len(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = st.Len()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt8(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int8 values (using modulo to stay in range).
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int8(i%128)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt8(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt16(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int16 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int16(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt16(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt32(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int32 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int32(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt32(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with int64 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, int64(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint8(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint8 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint8(i%256)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint8(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint16(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint16 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint16(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint16(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint32(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint32 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint32(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint32(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUint64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uint64 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uint64(i)) //nolint:gosec // This is a test
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUint64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetUintptr(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with uintptr values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, uintptr(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetUintptr(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetFloat32(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with float32 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, float32(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetFloat32(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetComplex64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with complex64 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
// Create a complex64 value with both real and imaginary parts.
|
||||
st.Set(key, complex(float32(i), float32(i)))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetComplex64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetComplex128(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
st := newState()
|
||||
n := 1000
|
||||
// Pre-populate the state with complex128 values.
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
// Create a complex128 value with both real and imaginary parts.
|
||||
st.Set(key, complex(float64(i), float64(i)))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetComplex128(key)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue