diff --git a/app.go b/app.go index 84059e79..8d01f3f9 100644 --- a/app.go +++ b/app.go @@ -129,6 +129,8 @@ type App struct { handlersCount uint32 // contains the information if the route stack has been changed to build the optimized tree routesRefreshed bool + // state management + state *State } // Config is a struct holding the server settings. @@ -138,7 +140,7 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa // Default: "" ServerHeader string `json:"server_header"` - // When set to true, the router treats "/foo" and "/foo/" as different. + // When set to Ftrue, the router treats "/foo" and "/foo/" as different. // By default this is disabled and both "/foo" and "/foo/" will execute the same handler. // // Default: false @@ -515,6 +517,9 @@ func New(config ...Config) *App { // Define mountFields app.mountFields = newMountFields(app) + // Define state + app.state = newState() + // Override config if provided if len(config) > 0 { app.config = config[0] @@ -952,6 +957,11 @@ func (app *App) Hooks() *Hooks { return app.hooks } +// State returns the state struct to store global data in order to share it between handlers. +func (app *App) State() *State { + return app.state +} + var ErrTestGotEmptyResponse = errors.New("test: got empty response") // TestConfig is a struct holding Test settings diff --git a/app_test.go b/app_test.go index 84606e6b..b5d5ed46 100644 --- a/app_test.go +++ b/app_test.go @@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) { require.Equal(t, "/simple-route", sRoute2.Path) } +func Test_App_State(t *testing.T) { + t.Parallel() + app := New() + + app.State().Set("key", "value") + str, ok := app.State().GetString("key") + require.True(t, ok) + require.Equal(t, "value", str) +} + // go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4 func Benchmark_Communication_Flow(b *testing.B) { app := New() diff --git a/state.go b/state.go new file mode 100644 index 00000000..73c3aeb7 --- /dev/null +++ b/state.go @@ -0,0 +1,134 @@ +package fiber + +import "sync" + +// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies. +// It's a thread-safe implementation of a map[string]any, using sync.Map. +type State struct { + dependencies sync.Map +} + +// NewState creates a new instance of State. +func newState() *State { + return &State{ + dependencies: sync.Map{}, + } +} + +// Set sets a key-value pair in the State. +func (s *State) Set(key string, value any) { + s.dependencies.Store(key, value) +} + +// Get retrieves a value from the State. +func (s *State) Get(key string) (any, bool) { + return s.dependencies.Load(key) +} + +// GetString retrieves a string value from the State. +func (s *State) GetString(key string) (string, bool) { + dep, ok := s.Get(key) + if ok { + depString, okCast := dep.(string) + return depString, okCast + } + + return "", false +} + +// GetInt retrieves an int value from the State. +func (s *State) GetInt(key string) (int, bool) { + dep, ok := s.Get(key) + if ok { + depInt, okCast := dep.(int) + return depInt, okCast + } + + return 0, false +} + +// GetBool retrieves a bool value from the State. +func (s *State) GetBool(key string) (bool, bool) { + dep, ok := s.Get(key) + if ok { + depBool, okCast := dep.(bool) + return depBool, okCast + } + + return false, false +} + +// GetFloat64 retrieves a float64 value from the State. +func (s *State) GetFloat64(key string) (float64, bool) { + dep, ok := s.Get(key) + if ok { + depFloat64, okCast := dep.(float64) + return depFloat64, okCast + } + + return 0, false +} + +// MustGet retrieves a value from the State and panics if the key is not found. +func (s *State) MustGet(key string) any { + if dep, ok := s.Get(key); ok { + return dep + } + + panic("state: dependency not found!") +} + +// MustGetString retrieves a string value from the State and panics if the key is not found. +func (s *State) Delete(key string) { + s.dependencies.Delete(key) +} + +// Reset resets the State. +func (s *State) Clear() { + s.dependencies.Clear() +} + +// Keys retrieves all the keys from the State. +func (s *State) Keys() []string { + keys := make([]string, 0) + s.dependencies.Range(func(key, _ any) bool { + keys = append(keys, key.(string)) + return true + }) + + return keys +} + +// Len retrieves the number of dependencies in the State. +func (s *State) Len() int { + length := 0 + s.dependencies.Range(func(_, _ any) bool { + length++ + return true + }) + + return length +} + +// GetState retrieves a value from the State and casts it to the desired type. +func GetState[T any](s *State, key string) (T, bool) { + dep, ok := s.Get(key) + + if ok { + depT, okCast := dep.(T) + return depT, okCast + } + + var zeroVal T + return zeroVal, false +} + +// MustGetState retrieves a value from the State and casts it to the desired type, panicking if the key is not found. +func MustGetState[T any](s *State, key string) T { + dep, ok := GetState[T](s, key) + if !ok { + panic("state: dependency not found!") + } + + return dep +} diff --git a/state_test.go b/state_test.go new file mode 100644 index 00000000..d5f766a1 --- /dev/null +++ b/state_test.go @@ -0,0 +1,423 @@ +package fiber + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestState_SetAndGet(t *testing.T) { + t.Parallel() + st := newState() + + // test setting and getting a value + st.Set("foo", "bar") + val, ok := st.Get("foo") + require.True(t, ok) + require.Equal(t, "bar", val) + + // test key not found + _, ok = st.Get("unknown") + require.False(t, ok) +} + +func TestState_GetString(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("str", "hello") + s, ok := st.GetString("str") + require.True(t, ok) + require.Equal(t, "hello", s) + + // wrong type should return false + st.Set("num", 123) + s, ok = st.GetString("num") + require.False(t, ok) + require.Equal(t, "", s) +} + +func TestState_GetInt(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("num", 456) + i, ok := st.GetInt("num") + require.True(t, ok) + require.Equal(t, 456, i) + + // wrong type should return zero value + st.Set("str", "abc") + i, ok = st.GetInt("str") + require.False(t, ok) + require.Equal(t, 0, i) +} + +func TestState_GetBool(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + b, ok := st.GetBool("flag") + require.True(t, ok) + require.True(t, b) + + // wrong type + st.Set("num", 1) + b, ok = st.GetBool("num") + require.False(t, ok) + require.False(t, b) +} + +func TestState_GetFloat64(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("pi", 3.14) + f, ok := st.GetFloat64("pi") + require.True(t, ok) + require.Equal(t, 3.14, f) + + // wrong type should return zero value + st.Set("int", 10) + f, ok = st.GetFloat64("int") + require.False(t, ok) + require.Equal(t, 0.0, f) +} + +func TestState_MustGet(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("exists", "value") + val := st.MustGet("exists") + require.Equal(t, "value", val) + + // must-get on missing key should panic + require.Panics(t, func() { + _ = st.MustGet("missing") + }) +} + +func TestState_Delete(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("key", "value") + st.Delete("key") + _, ok := st.Get("key") + require.False(t, ok) +} + +func TestState_Clear(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("a", 1) + st.Set("b", 2) + st.Clear() + require.Equal(t, 0, st.Len()) + require.Empty(t, st.Keys()) +} + +func TestState_Keys(t *testing.T) { + t.Parallel() + st := newState() + + keys := []string{"one", "two", "three"} + for _, k := range keys { + st.Set(k, k) + } + + returnedKeys := st.Keys() + require.ElementsMatch(t, keys, returnedKeys) +} + +func TestState_Len(t *testing.T) { + t.Parallel() + st := newState() + + require.Equal(t, 0, st.Len()) + + st.Set("a", "a") + require.Equal(t, 1, st.Len()) + + st.Set("b", "b") + require.Equal(t, 2, st.Len()) + + st.Delete("a") + require.Equal(t, 1, st.Len()) +} + +type testCase[T any] struct { + name string + key string + value any + expected T + ok bool +} + +func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) { + st := newState() + for _, tc := range tests { + st.Set(tc.key, tc.value) + got, ok := getter(st, tc.key) + require.Equal(t, tc.ok, ok, tc.name) + require.Equal(t, tc.expected, got, tc.name) + } +} + +func TestState_GetGeneric(t *testing.T) { + t.Parallel() + + runGenericTest[int](t, GetState[int], []testCase[int]{ + {"int correct conversion", "num", 42, 42, true}, + {"int wrong conversion from string", "str", "abc", 0, false}, + }) + + runGenericTest[string](t, GetState[string], []testCase[string]{ + {"string correct conversion", "strVal", "hello", "hello", true}, + {"string wrong conversion from int", "intVal", 100, "", false}, + }) + + runGenericTest[bool](t, GetState[bool], []testCase[bool]{ + {"bool correct conversion", "flag", true, true, true}, + {"bool wrong conversion from int", "intFlag", 1, false, false}, + }) + + runGenericTest[float64](t, GetState[float64], []testCase[float64]{ + {"float64 correct conversion", "pi", 3.14, 3.14, true}, + {"float64 wrong conversion from int", "intVal", 10, 0.0, false}, + }) +} + +func Test_MustGetStateGeneric(t *testing.T) { + t.Parallel() + st := newState() + + st.Set("flag", true) + flag := MustGetState[bool](st, "flag") + require.True(t, flag) + + // mismatched type should panic + require.Panics(t, func() { + _ = MustGetState[string](st, "flag") + }) + + // missing key should also panic + require.Panics(t, func() { + _ = MustGetState[string](st, "missing") + }) +} + +func BenchmarkState_Set(b *testing.B) { + b.ReportAllocs() + + st := newState() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } +} + +func BenchmarkState_Get(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.Get(key) + } +} + +func BenchmarkState_GetString(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, strconv.Itoa(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetString(key) + } +} + +func BenchmarkState_GetInt(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetInt(key) + } +} + +func BenchmarkState_GetBool(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i%2 == 0) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetBool(key) + } +} + +func BenchmarkState_GetFloat64(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.GetFloat64(key) + } +} + +func BenchmarkState_MustGet(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + st.MustGet(key) + } +} + +func BenchmarkState_GetStateGeneric(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + GetState[int](st, key) + } +} + +func BenchmarkState_MustGetStateGeneric(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + // pre-populate the state + for i := 0; i < n; i++ { + key := "key" + strconv.Itoa(i) + st.Set(key, i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + strconv.Itoa(i%n) + MustGetState[int](st, key) + } +} + +func BenchmarkState_Delete(b *testing.B) { + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st := newState() + st.Set("a", 1) + st.Delete("a") + } +} + +func BenchmarkState_Clear(b *testing.B) { + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + st := newState() + // add a fixed number of keys before clearing + for j := 0; j < 100; j++ { + st.Set("key"+strconv.Itoa(j), j) + } + st.Clear() + } +} + +func BenchmarkState_Keys(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = st.Keys() + } +} + +func BenchmarkState_Len(b *testing.B) { + b.ReportAllocs() + + st := newState() + n := 1000 + for i := 0; i < n; i++ { + st.Set("key"+strconv.Itoa(i), i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = st.Len() + } +}