mirror of https://github.com/gofiber/fiber.git
✨ feat: add support for application state management
parent
395c8fafa9
commit
655677e8bd
12
app.go
12
app.go
|
@ -129,6 +129,8 @@ type App struct {
|
|||
handlersCount uint32
|
||||
// contains the information if the route stack has been changed to build the optimized tree
|
||||
routesRefreshed bool
|
||||
// state management
|
||||
state *State
|
||||
}
|
||||
|
||||
// Config is a struct holding the server settings.
|
||||
|
@ -138,7 +140,7 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa
|
|||
// Default: ""
|
||||
ServerHeader string `json:"server_header"`
|
||||
|
||||
// When set to true, the router treats "/foo" and "/foo/" as different.
|
||||
// When set to Ftrue, the router treats "/foo" and "/foo/" as different.
|
||||
// By default this is disabled and both "/foo" and "/foo/" will execute the same handler.
|
||||
//
|
||||
// Default: false
|
||||
|
@ -515,6 +517,9 @@ func New(config ...Config) *App {
|
|||
// Define mountFields
|
||||
app.mountFields = newMountFields(app)
|
||||
|
||||
// Define state
|
||||
app.state = newState()
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
app.config = config[0]
|
||||
|
@ -952,6 +957,11 @@ func (app *App) Hooks() *Hooks {
|
|||
return app.hooks
|
||||
}
|
||||
|
||||
// State returns the state struct to store global data in order to share it between handlers.
|
||||
func (app *App) State() *State {
|
||||
return app.state
|
||||
}
|
||||
|
||||
var ErrTestGotEmptyResponse = errors.New("test: got empty response")
|
||||
|
||||
// TestConfig is a struct holding Test settings
|
||||
|
|
10
app_test.go
10
app_test.go
|
@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) {
|
|||
require.Equal(t, "/simple-route", sRoute2.Path)
|
||||
}
|
||||
|
||||
func Test_App_State(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := New()
|
||||
|
||||
app.State().Set("key", "value")
|
||||
str, ok := app.State().GetString("key")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "value", str)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4
|
||||
func Benchmark_Communication_Flow(b *testing.B) {
|
||||
app := New()
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
package fiber
|
||||
|
||||
import "sync"
|
||||
|
||||
// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies.
|
||||
// It's a thread-safe implementation of a map[string]any, using sync.Map.
|
||||
type State struct {
|
||||
dependencies sync.Map
|
||||
}
|
||||
|
||||
// NewState creates a new instance of State.
|
||||
func newState() *State {
|
||||
return &State{
|
||||
dependencies: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a key-value pair in the State.
|
||||
func (s *State) Set(key string, value any) {
|
||||
s.dependencies.Store(key, value)
|
||||
}
|
||||
|
||||
// Get retrieves a value from the State.
|
||||
func (s *State) Get(key string) (any, bool) {
|
||||
return s.dependencies.Load(key)
|
||||
}
|
||||
|
||||
// GetString retrieves a string value from the State.
|
||||
func (s *State) GetString(key string) (string, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depString, okCast := dep.(string)
|
||||
return depString, okCast
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetInt retrieves an int value from the State.
|
||||
func (s *State) GetInt(key string) (int, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depInt, okCast := dep.(int)
|
||||
return depInt, okCast
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetBool retrieves a bool value from the State.
|
||||
func (s *State) GetBool(key string) (bool, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depBool, okCast := dep.(bool)
|
||||
return depBool, okCast
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetFloat64 retrieves a float64 value from the State.
|
||||
func (s *State) GetFloat64(key string) (float64, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
if ok {
|
||||
depFloat64, okCast := dep.(float64)
|
||||
return depFloat64, okCast
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// MustGet retrieves a value from the State and panics if the key is not found.
|
||||
func (s *State) MustGet(key string) any {
|
||||
if dep, ok := s.Get(key); ok {
|
||||
return dep
|
||||
}
|
||||
|
||||
panic("state: dependency not found!")
|
||||
}
|
||||
|
||||
// MustGetString retrieves a string value from the State and panics if the key is not found.
|
||||
func (s *State) Delete(key string) {
|
||||
s.dependencies.Delete(key)
|
||||
}
|
||||
|
||||
// Reset resets the State.
|
||||
func (s *State) Clear() {
|
||||
s.dependencies.Clear()
|
||||
}
|
||||
|
||||
// Keys retrieves all the keys from the State.
|
||||
func (s *State) Keys() []string {
|
||||
keys := make([]string, 0)
|
||||
s.dependencies.Range(func(key, _ any) bool {
|
||||
keys = append(keys, key.(string))
|
||||
return true
|
||||
})
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// Len retrieves the number of dependencies in the State.
|
||||
func (s *State) Len() int {
|
||||
length := 0
|
||||
s.dependencies.Range(func(_, _ any) bool {
|
||||
length++
|
||||
return true
|
||||
})
|
||||
|
||||
return length
|
||||
}
|
||||
|
||||
// GetState retrieves a value from the State and casts it to the desired type.
|
||||
func GetState[T any](s *State, key string) (T, bool) {
|
||||
dep, ok := s.Get(key)
|
||||
|
||||
if ok {
|
||||
depT, okCast := dep.(T)
|
||||
return depT, okCast
|
||||
}
|
||||
|
||||
var zeroVal T
|
||||
return zeroVal, false
|
||||
}
|
||||
|
||||
// MustGetState retrieves a value from the State and casts it to the desired type, panicking if the key is not found.
|
||||
func MustGetState[T any](s *State, key string) T {
|
||||
dep, ok := GetState[T](s, key)
|
||||
if !ok {
|
||||
panic("state: dependency not found!")
|
||||
}
|
||||
|
||||
return dep
|
||||
}
|
|
@ -0,0 +1,423 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestState_SetAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
// test setting and getting a value
|
||||
st.Set("foo", "bar")
|
||||
val, ok := st.Get("foo")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "bar", val)
|
||||
|
||||
// test key not found
|
||||
_, ok = st.Get("unknown")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_GetString(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("str", "hello")
|
||||
s, ok := st.GetString("str")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "hello", s)
|
||||
|
||||
// wrong type should return false
|
||||
st.Set("num", 123)
|
||||
s, ok = st.GetString("num")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, "", s)
|
||||
}
|
||||
|
||||
func TestState_GetInt(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("num", 456)
|
||||
i, ok := st.GetInt("num")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 456, i)
|
||||
|
||||
// wrong type should return zero value
|
||||
st.Set("str", "abc")
|
||||
i, ok = st.GetInt("str")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0, i)
|
||||
}
|
||||
|
||||
func TestState_GetBool(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
b, ok := st.GetBool("flag")
|
||||
require.True(t, ok)
|
||||
require.True(t, b)
|
||||
|
||||
// wrong type
|
||||
st.Set("num", 1)
|
||||
b, ok = st.GetBool("num")
|
||||
require.False(t, ok)
|
||||
require.False(t, b)
|
||||
}
|
||||
|
||||
func TestState_GetFloat64(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("pi", 3.14)
|
||||
f, ok := st.GetFloat64("pi")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 3.14, f)
|
||||
|
||||
// wrong type should return zero value
|
||||
st.Set("int", 10)
|
||||
f, ok = st.GetFloat64("int")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0.0, f)
|
||||
}
|
||||
|
||||
func TestState_MustGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("exists", "value")
|
||||
val := st.MustGet("exists")
|
||||
require.Equal(t, "value", val)
|
||||
|
||||
// must-get on missing key should panic
|
||||
require.Panics(t, func() {
|
||||
_ = st.MustGet("missing")
|
||||
})
|
||||
}
|
||||
|
||||
func TestState_Delete(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("key", "value")
|
||||
st.Delete("key")
|
||||
_, ok := st.Get("key")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestState_Clear(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("a", 1)
|
||||
st.Set("b", 2)
|
||||
st.Clear()
|
||||
require.Equal(t, 0, st.Len())
|
||||
require.Empty(t, st.Keys())
|
||||
}
|
||||
|
||||
func TestState_Keys(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
keys := []string{"one", "two", "three"}
|
||||
for _, k := range keys {
|
||||
st.Set(k, k)
|
||||
}
|
||||
|
||||
returnedKeys := st.Keys()
|
||||
require.ElementsMatch(t, keys, returnedKeys)
|
||||
}
|
||||
|
||||
func TestState_Len(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
require.Equal(t, 0, st.Len())
|
||||
|
||||
st.Set("a", "a")
|
||||
require.Equal(t, 1, st.Len())
|
||||
|
||||
st.Set("b", "b")
|
||||
require.Equal(t, 2, st.Len())
|
||||
|
||||
st.Delete("a")
|
||||
require.Equal(t, 1, st.Len())
|
||||
}
|
||||
|
||||
type testCase[T any] struct {
|
||||
name string
|
||||
key string
|
||||
value any
|
||||
expected T
|
||||
ok bool
|
||||
}
|
||||
|
||||
func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) {
|
||||
st := newState()
|
||||
for _, tc := range tests {
|
||||
st.Set(tc.key, tc.value)
|
||||
got, ok := getter(st, tc.key)
|
||||
require.Equal(t, tc.ok, ok, tc.name)
|
||||
require.Equal(t, tc.expected, got, tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_GetGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runGenericTest[int](t, GetState[int], []testCase[int]{
|
||||
{"int correct conversion", "num", 42, 42, true},
|
||||
{"int wrong conversion from string", "str", "abc", 0, false},
|
||||
})
|
||||
|
||||
runGenericTest[string](t, GetState[string], []testCase[string]{
|
||||
{"string correct conversion", "strVal", "hello", "hello", true},
|
||||
{"string wrong conversion from int", "intVal", 100, "", false},
|
||||
})
|
||||
|
||||
runGenericTest[bool](t, GetState[bool], []testCase[bool]{
|
||||
{"bool correct conversion", "flag", true, true, true},
|
||||
{"bool wrong conversion from int", "intFlag", 1, false, false},
|
||||
})
|
||||
|
||||
runGenericTest[float64](t, GetState[float64], []testCase[float64]{
|
||||
{"float64 correct conversion", "pi", 3.14, 3.14, true},
|
||||
{"float64 wrong conversion from int", "intVal", 10, 0.0, false},
|
||||
})
|
||||
}
|
||||
|
||||
func Test_MustGetStateGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newState()
|
||||
|
||||
st.Set("flag", true)
|
||||
flag := MustGetState[bool](st, "flag")
|
||||
require.True(t, flag)
|
||||
|
||||
// mismatched type should panic
|
||||
require.Panics(t, func() {
|
||||
_ = MustGetState[string](st, "flag")
|
||||
})
|
||||
|
||||
// missing key should also panic
|
||||
require.Panics(t, func() {
|
||||
_ = MustGetState[string](st, "missing")
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkState_Set(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Get(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.Get(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetString(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, strconv.Itoa(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetString(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetInt(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetInt(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetBool(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i%2 == 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetBool(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetFloat64(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, float64(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.GetFloat64(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_MustGet(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
st.MustGet(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_GetStateGeneric(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
GetState[int](st, key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_MustGetStateGeneric(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
// pre-populate the state
|
||||
for i := 0; i < n; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
st.Set(key, i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "key" + strconv.Itoa(i%n)
|
||||
MustGetState[int](st, key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Delete(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st := newState()
|
||||
st.Set("a", 1)
|
||||
st.Delete("a")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Clear(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
st := newState()
|
||||
// add a fixed number of keys before clearing
|
||||
for j := 0; j < 100; j++ {
|
||||
st.Set("key"+strconv.Itoa(j), j)
|
||||
}
|
||||
st.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Keys(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = st.Keys()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkState_Len(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
st := newState()
|
||||
n := 1000
|
||||
for i := 0; i < n; i++ {
|
||||
st.Set("key"+strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = st.Len()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue