feat: add support for application state management

state-management
Muhammed Efe Cetin 2024-10-07 23:44:33 +03:00
parent 395c8fafa9
commit 655677e8bd
No known key found for this signature in database
GPG Key ID: 0AA4D45CBAA86F73
4 changed files with 578 additions and 1 deletions

12
app.go
View File

@ -129,6 +129,8 @@ type App struct {
handlersCount uint32
// contains the information if the route stack has been changed to build the optimized tree
routesRefreshed bool
// state management
state *State
}
// Config is a struct holding the server settings.
@ -138,7 +140,7 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa
// Default: ""
ServerHeader string `json:"server_header"`
// When set to true, the router treats "/foo" and "/foo/" as different.
// When set to Ftrue, the router treats "/foo" and "/foo/" as different.
// By default this is disabled and both "/foo" and "/foo/" will execute the same handler.
//
// Default: false
@ -515,6 +517,9 @@ func New(config ...Config) *App {
// Define mountFields
app.mountFields = newMountFields(app)
// Define state
app.state = newState()
// Override config if provided
if len(config) > 0 {
app.config = config[0]
@ -952,6 +957,11 @@ func (app *App) Hooks() *Hooks {
return app.hooks
}
// State returns the state struct to store global data in order to share it between handlers.
func (app *App) State() *State {
return app.state
}
var ErrTestGotEmptyResponse = errors.New("test: got empty response")
// TestConfig is a struct holding Test settings

View File

@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) {
require.Equal(t, "/simple-route", sRoute2.Path)
}
func Test_App_State(t *testing.T) {
t.Parallel()
app := New()
app.State().Set("key", "value")
str, ok := app.State().GetString("key")
require.True(t, ok)
require.Equal(t, "value", str)
}
// go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4
func Benchmark_Communication_Flow(b *testing.B) {
app := New()

134
state.go Normal file
View File

@ -0,0 +1,134 @@
package fiber
import "sync"
// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies.
// It's a thread-safe implementation of a map[string]any, using sync.Map.
type State struct {
dependencies sync.Map
}
// NewState creates a new instance of State.
func newState() *State {
return &State{
dependencies: sync.Map{},
}
}
// Set sets a key-value pair in the State.
func (s *State) Set(key string, value any) {
s.dependencies.Store(key, value)
}
// Get retrieves a value from the State.
func (s *State) Get(key string) (any, bool) {
return s.dependencies.Load(key)
}
// GetString retrieves a string value from the State.
func (s *State) GetString(key string) (string, bool) {
dep, ok := s.Get(key)
if ok {
depString, okCast := dep.(string)
return depString, okCast
}
return "", false
}
// GetInt retrieves an int value from the State.
func (s *State) GetInt(key string) (int, bool) {
dep, ok := s.Get(key)
if ok {
depInt, okCast := dep.(int)
return depInt, okCast
}
return 0, false
}
// GetBool retrieves a bool value from the State.
func (s *State) GetBool(key string) (bool, bool) {
dep, ok := s.Get(key)
if ok {
depBool, okCast := dep.(bool)
return depBool, okCast
}
return false, false
}
// GetFloat64 retrieves a float64 value from the State.
func (s *State) GetFloat64(key string) (float64, bool) {
dep, ok := s.Get(key)
if ok {
depFloat64, okCast := dep.(float64)
return depFloat64, okCast
}
return 0, false
}
// MustGet retrieves a value from the State and panics if the key is not found.
func (s *State) MustGet(key string) any {
if dep, ok := s.Get(key); ok {
return dep
}
panic("state: dependency not found!")
}
// MustGetString retrieves a string value from the State and panics if the key is not found.
func (s *State) Delete(key string) {
s.dependencies.Delete(key)
}
// Reset resets the State.
func (s *State) Clear() {
s.dependencies.Clear()
}
// Keys retrieves all the keys from the State.
func (s *State) Keys() []string {
keys := make([]string, 0)
s.dependencies.Range(func(key, _ any) bool {
keys = append(keys, key.(string))
return true
})
return keys
}
// Len retrieves the number of dependencies in the State.
func (s *State) Len() int {
length := 0
s.dependencies.Range(func(_, _ any) bool {
length++
return true
})
return length
}
// GetState retrieves a value from the State and casts it to the desired type.
func GetState[T any](s *State, key string) (T, bool) {
dep, ok := s.Get(key)
if ok {
depT, okCast := dep.(T)
return depT, okCast
}
var zeroVal T
return zeroVal, false
}
// MustGetState retrieves a value from the State and casts it to the desired type, panicking if the key is not found.
func MustGetState[T any](s *State, key string) T {
dep, ok := GetState[T](s, key)
if !ok {
panic("state: dependency not found!")
}
return dep
}

423
state_test.go Normal file
View File

@ -0,0 +1,423 @@
package fiber
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestState_SetAndGet(t *testing.T) {
t.Parallel()
st := newState()
// test setting and getting a value
st.Set("foo", "bar")
val, ok := st.Get("foo")
require.True(t, ok)
require.Equal(t, "bar", val)
// test key not found
_, ok = st.Get("unknown")
require.False(t, ok)
}
func TestState_GetString(t *testing.T) {
t.Parallel()
st := newState()
st.Set("str", "hello")
s, ok := st.GetString("str")
require.True(t, ok)
require.Equal(t, "hello", s)
// wrong type should return false
st.Set("num", 123)
s, ok = st.GetString("num")
require.False(t, ok)
require.Equal(t, "", s)
}
func TestState_GetInt(t *testing.T) {
t.Parallel()
st := newState()
st.Set("num", 456)
i, ok := st.GetInt("num")
require.True(t, ok)
require.Equal(t, 456, i)
// wrong type should return zero value
st.Set("str", "abc")
i, ok = st.GetInt("str")
require.False(t, ok)
require.Equal(t, 0, i)
}
func TestState_GetBool(t *testing.T) {
t.Parallel()
st := newState()
st.Set("flag", true)
b, ok := st.GetBool("flag")
require.True(t, ok)
require.True(t, b)
// wrong type
st.Set("num", 1)
b, ok = st.GetBool("num")
require.False(t, ok)
require.False(t, b)
}
func TestState_GetFloat64(t *testing.T) {
t.Parallel()
st := newState()
st.Set("pi", 3.14)
f, ok := st.GetFloat64("pi")
require.True(t, ok)
require.Equal(t, 3.14, f)
// wrong type should return zero value
st.Set("int", 10)
f, ok = st.GetFloat64("int")
require.False(t, ok)
require.Equal(t, 0.0, f)
}
func TestState_MustGet(t *testing.T) {
t.Parallel()
st := newState()
st.Set("exists", "value")
val := st.MustGet("exists")
require.Equal(t, "value", val)
// must-get on missing key should panic
require.Panics(t, func() {
_ = st.MustGet("missing")
})
}
func TestState_Delete(t *testing.T) {
t.Parallel()
st := newState()
st.Set("key", "value")
st.Delete("key")
_, ok := st.Get("key")
require.False(t, ok)
}
func TestState_Clear(t *testing.T) {
t.Parallel()
st := newState()
st.Set("a", 1)
st.Set("b", 2)
st.Clear()
require.Equal(t, 0, st.Len())
require.Empty(t, st.Keys())
}
func TestState_Keys(t *testing.T) {
t.Parallel()
st := newState()
keys := []string{"one", "two", "three"}
for _, k := range keys {
st.Set(k, k)
}
returnedKeys := st.Keys()
require.ElementsMatch(t, keys, returnedKeys)
}
func TestState_Len(t *testing.T) {
t.Parallel()
st := newState()
require.Equal(t, 0, st.Len())
st.Set("a", "a")
require.Equal(t, 1, st.Len())
st.Set("b", "b")
require.Equal(t, 2, st.Len())
st.Delete("a")
require.Equal(t, 1, st.Len())
}
type testCase[T any] struct {
name string
key string
value any
expected T
ok bool
}
func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) {
st := newState()
for _, tc := range tests {
st.Set(tc.key, tc.value)
got, ok := getter(st, tc.key)
require.Equal(t, tc.ok, ok, tc.name)
require.Equal(t, tc.expected, got, tc.name)
}
}
func TestState_GetGeneric(t *testing.T) {
t.Parallel()
runGenericTest[int](t, GetState[int], []testCase[int]{
{"int correct conversion", "num", 42, 42, true},
{"int wrong conversion from string", "str", "abc", 0, false},
})
runGenericTest[string](t, GetState[string], []testCase[string]{
{"string correct conversion", "strVal", "hello", "hello", true},
{"string wrong conversion from int", "intVal", 100, "", false},
})
runGenericTest[bool](t, GetState[bool], []testCase[bool]{
{"bool correct conversion", "flag", true, true, true},
{"bool wrong conversion from int", "intFlag", 1, false, false},
})
runGenericTest[float64](t, GetState[float64], []testCase[float64]{
{"float64 correct conversion", "pi", 3.14, 3.14, true},
{"float64 wrong conversion from int", "intVal", 10, 0.0, false},
})
}
func Test_MustGetStateGeneric(t *testing.T) {
t.Parallel()
st := newState()
st.Set("flag", true)
flag := MustGetState[bool](st, "flag")
require.True(t, flag)
// mismatched type should panic
require.Panics(t, func() {
_ = MustGetState[string](st, "flag")
})
// missing key should also panic
require.Panics(t, func() {
_ = MustGetState[string](st, "missing")
})
}
func BenchmarkState_Set(b *testing.B) {
b.ReportAllocs()
st := newState()
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
}
func BenchmarkState_Get(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.Get(key)
}
}
func BenchmarkState_GetString(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, strconv.Itoa(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetString(key)
}
}
func BenchmarkState_GetInt(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetInt(key)
}
}
func BenchmarkState_GetBool(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i%2 == 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetBool(key)
}
}
func BenchmarkState_GetFloat64(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, float64(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.GetFloat64(key)
}
}
func BenchmarkState_MustGet(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
st.MustGet(key)
}
}
func BenchmarkState_GetStateGeneric(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
GetState[int](st, key)
}
}
func BenchmarkState_MustGetStateGeneric(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
// pre-populate the state
for i := 0; i < n; i++ {
key := "key" + strconv.Itoa(i)
st.Set(key, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := "key" + strconv.Itoa(i%n)
MustGetState[int](st, key)
}
}
func BenchmarkState_Delete(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
st := newState()
st.Set("a", 1)
st.Delete("a")
}
}
func BenchmarkState_Clear(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
st := newState()
// add a fixed number of keys before clearing
for j := 0; j < 100; j++ {
st.Set("key"+strconv.Itoa(j), j)
}
st.Clear()
}
}
func BenchmarkState_Keys(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
for i := 0; i < n; i++ {
st.Set("key"+strconv.Itoa(i), i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = st.Keys()
}
}
func BenchmarkState_Len(b *testing.B) {
b.ReportAllocs()
st := newState()
n := 1000
for i := 0; i < n; i++ {
st.Set("key"+strconv.Itoa(i), i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = st.Len()
}
}