diff --git a/app.go b/app.go index 8d01f3f9..2a8a5754 100644 --- a/app.go +++ b/app.go @@ -106,6 +106,8 @@ type App struct { tlsHandler *TLSHandler // Mount fields mountFields *mountFields + // state management + state *State // Route stack divided by HTTP methods stack [][]*Route // Route stack divided by HTTP methods and route prefixes @@ -129,8 +131,6 @@ type App struct { handlersCount uint32 // contains the information if the route stack has been changed to build the optimized tree routesRefreshed bool - // state management - state *State } // Config is a struct holding the server settings. diff --git a/state.go b/state.go index 73c3aeb7..90da98ee 100644 --- a/state.go +++ b/state.go @@ -1,6 +1,8 @@ package fiber -import "sync" +import ( + "sync" +) // State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies. // It's a thread-safe implementation of a map[string]any, using sync.Map. @@ -48,7 +50,7 @@ func (s *State) GetInt(key string) (int, bool) { } // GetBool retrieves a bool value from the State. -func (s *State) GetBool(key string) (bool, bool) { +func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here dep, ok := s.Get(key) if ok { depBool, okCast := dep.(bool) @@ -92,7 +94,12 @@ func (s *State) Clear() { func (s *State) Keys() []string { keys := make([]string, 0) s.dependencies.Range(func(key, _ any) bool { - keys = append(keys, key.(string)) + keyStr, ok := key.(string) + if !ok { + return false + } + + keys = append(keys, keyStr) return true }) diff --git a/state_test.go b/state_test.go index d573f470..b09f01ac 100644 --- a/state_test.go +++ b/state_test.go @@ -92,18 +92,18 @@ func TestState_GetFloat64(t *testing.T) { st.Set("pi", 3.14) f, ok := st.GetFloat64("pi") require.True(t, ok) - require.Equal(t, 3.14, f) + require.InDelta(t, 3.14, f, 0.0001) // wrong type should return zero value st.Set("int", 10) f, ok = st.GetFloat64("int") require.False(t, ok) - require.Equal(t, 0.0, f) + require.InDelta(t, 0.0, f, 0.0001) // missing key should return zero value f, ok = st.GetFloat64("missing") require.False(t, ok) - require.Equal(t, 0.0, f) + require.InDelta(t, 0.0, f, 0.0001) } func TestState_MustGet(t *testing.T) { @@ -170,7 +170,7 @@ func TestState_Len(t *testing.T) { require.Equal(t, 1, st.Len()) } -type testCase[T any] struct { +type testCase[T any] struct { //nolint:govet // It does not really matter for test name string key string value any @@ -179,6 +179,8 @@ type testCase[T any] struct { } func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) { + t.Helper() + st := newState() for _, tc := range tests { st.Set(tc.key, tc.value)