fiber/state_test.go

982 lines
20 KiB
Go

package fiber
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
// TestState_SetAndGet_WithApp exercises the State reached through an App:
// a stored value must be readable again and an unset key must be reported
// as not found.
func TestState_SetAndGet_WithApp(t *testing.T) {
	t.Parallel()

	application := New()

	// A key that was never stored is reported as absent.
	_, found := application.State().Get("unknown")
	require.False(t, found)

	// A stored value round-trips through the app's state.
	application.State().Set("foo", "bar")
	got, found := application.State().Get("foo")
	require.True(t, found)
	require.Equal(t, "bar", got)
}
// TestState_SetAndGet checks the basic Set/Get round trip on a standalone
// State, plus the not-found result for a key that was never stored.
func TestState_SetAndGet(t *testing.T) {
	t.Parallel()

	state := newState()

	// A key that was never stored is reported as absent.
	_, found := state.Get("unknown")
	require.False(t, found)

	// A stored value can be read back unchanged.
	state.Set("foo", "bar")
	got, found := state.Get("foo")
	require.True(t, found)
	require.Equal(t, "bar", got)
}
// TestState_GetString asserts that GetString returns a stored string, and
// returns ("", false) both for a value of another type and for a missing key.
func TestState_GetString(t *testing.T) {
	t.Parallel()

	state := newState()
	state.Set("str", "hello")
	state.Set("num", 123)

	got, found := state.GetString("str")
	require.True(t, found)
	require.Equal(t, "hello", got)

	// "num" holds an int and "missing" was never stored; both lookups fail.
	for _, key := range []string{"num", "missing"} {
		got, found = state.GetString(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, "", got, "key %q", key)
	}
}
// TestState_GetInt asserts that GetInt returns a stored int, and returns
// (0, false) both for a value of another type and for a missing key.
func TestState_GetInt(t *testing.T) {
	t.Parallel()

	state := newState()
	state.Set("num", 456)
	state.Set("str", "abc")

	got, found := state.GetInt("num")
	require.True(t, found)
	require.Equal(t, 456, got)

	// "str" holds a string and "missing" was never stored; both yield zero.
	for _, key := range []string{"str", "missing"} {
		got, found = state.GetInt(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, 0, got, "key %q", key)
	}
}
// TestState_GetBool asserts that GetBool returns a stored bool, and returns
// (false, false) both for a value of another type and for a missing key.
func TestState_GetBool(t *testing.T) {
	t.Parallel()

	state := newState()
	state.Set("flag", true)
	state.Set("num", 1)

	got, found := state.GetBool("flag")
	require.True(t, found)
	require.True(t, got)

	// "num" holds an int and "missing" was never stored; both lookups fail.
	for _, key := range []string{"num", "missing"} {
		got, found = state.GetBool(key)
		require.False(t, found, "key %q", key)
		require.False(t, got, "key %q", key)
	}
}
// TestState_GetFloat64 asserts that GetFloat64 returns a stored float64, and
// returns (0, false) both for a value of another type and for a missing key.
func TestState_GetFloat64(t *testing.T) {
	t.Parallel()

	state := newState()
	state.Set("pi", 3.14)
	state.Set("int", 10)

	got, found := state.GetFloat64("pi")
	require.True(t, found)
	require.InDelta(t, 3.14, got, 0.0001)

	// "int" holds an int and "missing" was never stored; both yield zero.
	for _, key := range []string{"int", "missing"} {
		got, found = state.GetFloat64(key)
		require.False(t, found, "key %q", key)
		require.InDelta(t, 0.0, got, 0.0001, "key %q", key)
	}
}
// TestState_GetUint asserts that GetUint returns a stored uint, and returns
// (0, false) both for a value of another type and for a missing key.
func TestState_GetUint(t *testing.T) {
	t.Parallel()

	const want uint = 100

	state := newState()
	state.Set("uint", want)
	state.Set("wrong", "not uint")

	got, found := state.GetUint("uint")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetUint(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, uint(0), got, "key %q", key)
	}
}
// TestState_GetInt8 asserts that GetInt8 returns a stored int8, and returns
// (0, false) both for a value of another type and for a missing key.
func TestState_GetInt8(t *testing.T) {
	t.Parallel()

	const want int8 = 10

	state := newState()
	state.Set("int8", want)
	state.Set("wrong", "not int8")

	got, found := state.GetInt8("int8")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetInt8(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, int8(0), got, "key %q", key)
	}
}
// TestState_GetInt16 asserts that GetInt16 returns a stored int16, and returns
// (0, false) both for a value of another type and for a missing key.
func TestState_GetInt16(t *testing.T) {
	t.Parallel()

	const want int16 = 200

	state := newState()
	state.Set("int16", want)
	state.Set("wrong", "not int16")

	got, found := state.GetInt16("int16")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetInt16(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, int16(0), got, "key %q", key)
	}
}
// TestState_GetInt32 asserts that GetInt32 returns a stored int32, and returns
// (0, false) both for a value of another type and for a missing key.
func TestState_GetInt32(t *testing.T) {
	t.Parallel()

	const want int32 = 3000

	state := newState()
	state.Set("int32", want)
	state.Set("wrong", "not int32")

	got, found := state.GetInt32("int32")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetInt32(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, int32(0), got, "key %q", key)
	}
}
// TestState_GetInt64 asserts that GetInt64 returns a stored int64, and returns
// (0, false) both for a value of another type and for a missing key.
func TestState_GetInt64(t *testing.T) {
	t.Parallel()

	const want int64 = 4000

	state := newState()
	state.Set("int64", want)
	state.Set("wrong", "not int64")

	got, found := state.GetInt64("int64")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetInt64(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, int64(0), got, "key %q", key)
	}
}
// TestState_GetUint8 asserts that GetUint8 returns a stored uint8, and returns
// (0, false) both for a value of another type and for a missing key.
func TestState_GetUint8(t *testing.T) {
	t.Parallel()

	const want uint8 = 20

	state := newState()
	state.Set("uint8", want)
	state.Set("wrong", "not uint8")

	got, found := state.GetUint8("uint8")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetUint8(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, uint8(0), got, "key %q", key)
	}
}
// TestState_GetUint16 asserts that GetUint16 returns a stored uint16, and
// returns (0, false) both for a value of another type and for a missing key.
func TestState_GetUint16(t *testing.T) {
	t.Parallel()

	const want uint16 = 300

	state := newState()
	state.Set("uint16", want)
	state.Set("wrong", "not uint16")

	got, found := state.GetUint16("uint16")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetUint16(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, uint16(0), got, "key %q", key)
	}
}
// TestState_GetUint32 asserts that GetUint32 returns a stored uint32, and
// returns (0, false) both for a value of another type and for a missing key.
func TestState_GetUint32(t *testing.T) {
	t.Parallel()

	const want uint32 = 400000

	state := newState()
	state.Set("uint32", want)
	state.Set("wrong", "not uint32")

	got, found := state.GetUint32("uint32")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetUint32(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, uint32(0), got, "key %q", key)
	}
}
// TestState_GetUint64 asserts that GetUint64 returns a stored uint64, and
// returns (0, false) both for a value of another type and for a missing key.
func TestState_GetUint64(t *testing.T) {
	t.Parallel()

	const want uint64 = 5000000

	state := newState()
	state.Set("uint64", want)
	state.Set("wrong", "not uint64")

	got, found := state.GetUint64("uint64")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetUint64(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, uint64(0), got, "key %q", key)
	}
}
// TestState_GetUintptr asserts that GetUintptr returns a stored uintptr, and
// returns (0, false) both for a value of another type and for a missing key.
func TestState_GetUintptr(t *testing.T) {
	t.Parallel()

	const want uintptr = 12345

	state := newState()
	state.Set("uintptr", want)
	state.Set("wrong", "not uintptr")

	got, found := state.GetUintptr("uintptr")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetUintptr(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, uintptr(0), got, "key %q", key)
	}
}
// TestState_GetFloat32 asserts that GetFloat32 returns a stored float32, and
// returns (0, false) both for a value of another type and for a missing key.
func TestState_GetFloat32(t *testing.T) {
	t.Parallel()

	const want float32 = 3.14

	state := newState()
	state.Set("float32", want)
	state.Set("wrong", "not float32")

	got, found := state.GetFloat32("float32")
	require.True(t, found)
	require.InDelta(t, want, got, 0.0001)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetFloat32(key)
		require.False(t, found, "key %q", key)
		require.InDelta(t, float32(0), got, 0.0001, "key %q", key)
	}
}
// TestState_GetComplex64 asserts that GetComplex64 returns a stored complex64,
// and returns (0, false) both for a value of another type and for a missing key.
func TestState_GetComplex64(t *testing.T) {
	t.Parallel()

	const want complex64 = complex(2, 3)

	state := newState()
	state.Set("complex64", want)
	state.Set("wrong", "not complex64")

	got, found := state.GetComplex64("complex64")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetComplex64(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, complex64(0), got, "key %q", key)
	}
}
// TestState_GetComplex128 asserts that GetComplex128 returns a stored
// complex128, and returns (0, false) both for a value of another type and for
// a missing key.
func TestState_GetComplex128(t *testing.T) {
	t.Parallel()

	const want complex128 = complex(4, 5)

	state := newState()
	state.Set("complex128", want)
	state.Set("wrong", "not complex128")

	got, found := state.GetComplex128("complex128")
	require.True(t, found)
	require.Equal(t, want, got)

	// A string value and an absent key must both yield the zero value.
	for _, key := range []string{"wrong", "missing"} {
		got, found = state.GetComplex128(key)
		require.False(t, found, "key %q", key)
		require.Equal(t, complex128(0), got, "key %q", key)
	}
}
// TestState_MustGet asserts that MustGet returns the stored value for an
// existing key and panics when the key is missing.
func TestState_MustGet(t *testing.T) {
	t.Parallel()

	state := newState()
	state.Set("exists", "value")

	require.Equal(t, "value", state.MustGet("exists"))

	// Looking up a key that was never stored is expected to panic.
	lookupMissing := func() {
		_ = state.MustGet("missing")
	}
	require.Panics(t, lookupMissing)
}
// TestState_Has asserts that Has reports true only for keys that have been
// stored. The absent-key assertions matter: without them an implementation
// that always returned true would pass.
func TestState_Has(t *testing.T) {
	t.Parallel()

	st := newState()

	// Before anything is stored the key must not be reported as present.
	require.False(t, st.Has("key"))

	st.Set("key", "value")
	require.True(t, st.Has("key"))

	// Storing one key must not make unrelated keys appear present.
	require.False(t, st.Has("missing"))
}
// TestState_Delete asserts that Delete removes exactly the requested key:
// the key is retrievable before the call, gone afterwards, and an unrelated
// key stored alongside it is left intact.
func TestState_Delete(t *testing.T) {
	t.Parallel()

	st := newState()
	st.Set("key", "value")
	st.Set("other", "kept")

	// Confirm the key is really present first so the post-delete check is
	// not vacuously true.
	_, ok := st.Get("key")
	require.True(t, ok)

	st.Delete("key")
	_, ok = st.Get("key")
	require.False(t, ok)

	// Removing one key must not disturb other entries.
	val, ok := st.Get("other")
	require.True(t, ok)
	require.Equal(t, "kept", val)
}
// TestState_Reset asserts that after Reset the state reports a length of zero
// and returns no keys.
func TestState_Reset(t *testing.T) {
	t.Parallel()

	state := newState()
	// Store "a" -> 1 and "b" -> 2 so there is something to clear.
	for idx, key := range []string{"a", "b"} {
		state.Set(key, idx+1)
	}

	state.Reset()

	require.Equal(t, 0, state.Len())
	require.Empty(t, state.Keys())
}
// TestState_Keys asserts that Keys returns every stored key; the comparison
// ignores ordering.
func TestState_Keys(t *testing.T) {
	t.Parallel()

	state := newState()
	want := []string{"one", "two", "three"}
	for idx := range want {
		// Each key is stored with itself as the value.
		state.Set(want[idx], want[idx])
	}

	require.ElementsMatch(t, want, state.Keys())
}
// TestState_Len asserts that Len tracks the number of stored entries as keys
// are added and removed.
func TestState_Len(t *testing.T) {
	t.Parallel()

	state := newState()
	expectLen := func(want int) {
		t.Helper()
		require.Equal(t, want, state.Len())
	}

	expectLen(0)

	state.Set("a", "a")
	expectLen(1)

	state.Set("b", "b")
	expectLen(2)

	state.Delete("a")
	expectLen(1)
}
// testCase is one table entry consumed by runGenericTest: value is stored
// under key, and the typed getter is then expected to return (expected, ok).
// name is passed to the assertions as the failure message.
//
// The field order is relied on by positional literals, so it must not change.
type testCase[T any] struct { //nolint:govet // It does not really matter for test
	name, key string
	value     any
	expected  T
	ok        bool
}
// runGenericTest drives a typed getter through a table of cases. All cases
// share one fresh State: each case stores its value under its key, calls
// getter for that key, and compares both results with the expectations,
// using the case name as the failure message.
func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) {
	t.Helper()

	state := newState()
	for idx := range tests {
		current := &tests[idx]
		state.Set(current.key, current.value)

		actual, found := getter(state, current.key)
		require.Equal(t, current.ok, found, current.name)
		require.Equal(t, current.expected, actual, current.name)
	}
}
// TestState_GetGeneric runs GetState for int, string, bool and float64, each
// with one value of the matching type and one value of a different type that
// must yield the zero value and false.
func TestState_GetGeneric(t *testing.T) {
	t.Parallel()

	runGenericTest(t, GetState[int], []testCase[int]{
		{name: "int correct conversion", key: "num", value: 42, expected: 42, ok: true},
		{name: "int wrong conversion from string", key: "str", value: "abc", expected: 0, ok: false},
	})

	runGenericTest(t, GetState[string], []testCase[string]{
		{name: "string correct conversion", key: "strVal", value: "hello", expected: "hello", ok: true},
		{name: "string wrong conversion from int", key: "intVal", value: 100, expected: "", ok: false},
	})

	runGenericTest(t, GetState[bool], []testCase[bool]{
		{name: "bool correct conversion", key: "flag", value: true, expected: true, ok: true},
		{name: "bool wrong conversion from int", key: "intFlag", value: 1, expected: false, ok: false},
	})

	runGenericTest(t, GetState[float64], []testCase[float64]{
		{name: "float64 correct conversion", key: "pi", value: 3.14, expected: 3.14, ok: true},
		{name: "float64 wrong conversion from int", key: "intVal", value: 10, expected: 0.0, ok: false},
	})
}
// Test_MustGetStateGeneric asserts that MustGetState returns a stored value
// of the requested type and panics on a type mismatch or a missing key.
func Test_MustGetStateGeneric(t *testing.T) {
	t.Parallel()

	state := newState()
	state.Set("flag", true)

	require.True(t, MustGetState[bool](state, "flag"))

	// "flag" holds a bool, not a string, and "missing" was never stored;
	// requesting either as a string must panic.
	for _, key := range []string{"flag", "missing"} {
		require.Panics(t, func() {
			_ = MustGetState[string](state, key)
		}, "key %q", key)
	}
}
// Test_GetStateWithDefault asserts that GetStateWithDefault returns the
// stored value when the type matches, and the supplied default on a type
// mismatch or a missing key.
func Test_GetStateWithDefault(t *testing.T) {
	t.Parallel()

	state := newState()
	state.Set("flag", true)

	// Matching type: the stored true wins over the default false.
	require.True(t, GetStateWithDefault(state, "flag", false))

	// "flag" holds a bool, so asking for a string falls back to the default.
	require.Equal(t, "default", GetStateWithDefault(state, "flag", "default"))

	// An absent key also falls back to the default.
	require.False(t, GetStateWithDefault(state, "missing", false))
}
// BenchmarkState_Set times Set with a distinct key on every iteration; the
// key formatting is part of the timed loop.
func BenchmarkState_Set(b *testing.B) {
	b.ReportAllocs()

	state := newState()

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		state.Set("key"+strconv.Itoa(iter), iter)
	}
}
// BenchmarkState_Get times Get against a State pre-populated with 1000 int
// entries. The lookup keys are built once before the timer is reset so that
// the timings and reported allocations cover the lookup, not string
// construction.
func BenchmarkState_Get(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.Get(keys[i%n])
	}
}
// BenchmarkState_GetString times GetString against a State pre-populated
// with 1000 string entries. The lookup keys are built once before the timer
// is reset so that the timings and reported allocations cover the lookup,
// not string construction.
func BenchmarkState_GetString(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		num := strconv.Itoa(i)
		keys[i] = "key" + num
		st.Set(keys[i], num)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetString(keys[i%n])
	}
}
// BenchmarkState_GetInt times GetInt against a State pre-populated with 1000
// int entries. The lookup keys are built once before the timer is reset so
// that the timings and reported allocations cover the lookup, not string
// construction.
func BenchmarkState_GetInt(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetInt(keys[i%n])
	}
}
// BenchmarkState_GetBool times GetBool against a State pre-populated with
// 1000 bool entries (true for even indices). The lookup keys are built once
// before the timer is reset so that the timings and reported allocations
// cover the lookup, not string construction.
func BenchmarkState_GetBool(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i%2 == 0)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetBool(keys[i%n])
	}
}
// BenchmarkState_GetFloat64 times GetFloat64 against a State pre-populated
// with 1000 float64 entries. The lookup keys are built once before the timer
// is reset so that the timings and reported allocations cover the lookup,
// not string construction.
func BenchmarkState_GetFloat64(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], float64(i))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetFloat64(keys[i%n])
	}
}
// BenchmarkState_MustGet times MustGet against a State pre-populated with
// 1000 int entries; every looked-up key exists. The lookup keys are built
// once before the timer is reset so that the timings and reported
// allocations cover the lookup, not string construction.
func BenchmarkState_MustGet(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.MustGet(keys[i%n])
	}
}
// BenchmarkState_GetStateGeneric times the generic GetState[int] helper
// against a State pre-populated with 1000 int entries. The lookup keys are
// built once before the timer is reset so that the timings and reported
// allocations cover the lookup, not string construction.
func BenchmarkState_GetStateGeneric(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		GetState[int](st, keys[i%n])
	}
}
// BenchmarkState_MustGetStateGeneric times the generic MustGetState[int]
// helper against a State pre-populated with 1000 int entries; every
// looked-up key exists. The lookup keys are built once before the timer is
// reset so that the timings and reported allocations cover the lookup, not
// string construction.
func BenchmarkState_MustGetStateGeneric(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		MustGetState[int](st, keys[i%n])
	}
}
// BenchmarkState_GetStateWithDefault times GetStateWithDefault[int] against
// a State pre-populated with 1000 int entries. The lookup keys are built
// once before the timer is reset so that the timings and reported
// allocations cover the lookup, not string construction.
func BenchmarkState_GetStateWithDefault(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		GetStateWithDefault[int](st, keys[i%n], 0)
	}
}
// BenchmarkState_Has times Has against a State pre-populated with 1000 int
// entries. The lookup keys are built once before the timer is reset so that
// the timings and reported allocations cover the lookup, not string
// construction.
func BenchmarkState_Has(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state, remembering each key for the timed loop.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.Has(keys[i%n])
	}
}
// BenchmarkState_Delete times creating a State, storing one entry and
// deleting it again; all three steps happen inside the timed loop.
func BenchmarkState_Delete(b *testing.B) {
	b.ReportAllocs()
	b.ResetTimer()

	for iter := 0; iter < b.N; iter++ {
		state := newState()
		state.Set("a", 1)
		state.Delete("a")
	}
}
// BenchmarkState_Reset times filling a fresh State with 100 entries and then
// calling Reset. The same 100 keys are used on every iteration, so they are
// built once up front rather than re-formatted inside the timed loop.
func BenchmarkState_Reset(b *testing.B) {
	b.ReportAllocs()

	const numKeys = 100
	keys := make([]string, numKeys)
	for j := range keys {
		keys[j] = "key" + strconv.Itoa(j)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st := newState()
		// add a fixed number of keys before clearing
		for j, key := range keys {
			st.Set(key, j)
		}
		st.Reset()
	}
}
// BenchmarkState_Keys times Keys on a State holding 1000 entries.
func BenchmarkState_Keys(b *testing.B) {
	b.ReportAllocs()

	const entries = 1000
	state := newState()
	for idx := 0; idx < entries; idx++ {
		state.Set("key"+strconv.Itoa(idx), idx)
	}

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_ = state.Keys()
	}
}
// BenchmarkState_Len times Len on a State holding 1000 entries.
func BenchmarkState_Len(b *testing.B) {
	b.ReportAllocs()

	const entries = 1000
	state := newState()
	for idx := 0; idx < entries; idx++ {
		state.Set("key"+strconv.Itoa(idx), idx)
	}

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_ = state.Len()
	}
}
// BenchmarkState_GetUint times GetUint against a State pre-populated with
// 1000 uint entries. The lookup keys are built once before the timer is
// reset so that the timings and reported allocations cover the lookup, not
// string construction.
func BenchmarkState_GetUint(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with uint values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint(i)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetUint(keys[i%n])
	}
}
// BenchmarkState_GetInt8 times GetInt8 against a State pre-populated with
// 1000 int8 entries. The lookup keys are built once before the timer is
// reset so that the timings and reported allocations cover the lookup, not
// string construction.
func BenchmarkState_GetInt8(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with int8 values (using modulo to stay in
	// range), remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int8(i%128)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetInt8(keys[i%n])
	}
}
// BenchmarkState_GetInt16 times GetInt16 against a State pre-populated with
// 1000 int16 entries. The lookup keys are built once before the timer is
// reset so that the timings and reported allocations cover the lookup, not
// string construction.
func BenchmarkState_GetInt16(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with int16 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int16(i)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetInt16(keys[i%n])
	}
}
// BenchmarkState_GetInt32 times GetInt32 against a State pre-populated with
// 1000 int32 entries. The lookup keys are built once before the timer is
// reset so that the timings and reported allocations cover the lookup, not
// string construction.
func BenchmarkState_GetInt32(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with int32 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int32(i)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetInt32(keys[i%n])
	}
}
// BenchmarkState_GetInt64 times GetInt64 against a State pre-populated with
// 1000 int64 entries. The lookup keys are built once before the timer is
// reset so that the timings and reported allocations cover the lookup, not
// string construction.
func BenchmarkState_GetInt64(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with int64 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int64(i))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetInt64(keys[i%n])
	}
}
// BenchmarkState_GetUint8 times GetUint8 against a State pre-populated with
// 1000 uint8 entries. The lookup keys are built once before the timer is
// reset so that the timings and reported allocations cover the lookup, not
// string construction.
func BenchmarkState_GetUint8(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with uint8 values (modulo 256 keeps them in
	// range), remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint8(i%256)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetUint8(keys[i%n])
	}
}
// BenchmarkState_GetUint16 times GetUint16 against a State pre-populated
// with 1000 uint16 entries. The lookup keys are built once before the timer
// is reset so that the timings and reported allocations cover the lookup,
// not string construction.
func BenchmarkState_GetUint16(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with uint16 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint16(i)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetUint16(keys[i%n])
	}
}
// BenchmarkState_GetUint32 times GetUint32 against a State pre-populated
// with 1000 uint32 entries. The lookup keys are built once before the timer
// is reset so that the timings and reported allocations cover the lookup,
// not string construction.
func BenchmarkState_GetUint32(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with uint32 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint32(i)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetUint32(keys[i%n])
	}
}
// BenchmarkState_GetUint64 times GetUint64 against a State pre-populated
// with 1000 uint64 entries. The lookup keys are built once before the timer
// is reset so that the timings and reported allocations cover the lookup,
// not string construction.
func BenchmarkState_GetUint64(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with uint64 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint64(i)) //nolint:gosec // This is a test
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetUint64(keys[i%n])
	}
}
// BenchmarkState_GetUintptr times GetUintptr against a State pre-populated
// with 1000 uintptr entries. The lookup keys are built once before the timer
// is reset so that the timings and reported allocations cover the lookup,
// not string construction.
func BenchmarkState_GetUintptr(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with uintptr values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uintptr(i))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetUintptr(keys[i%n])
	}
}
// BenchmarkState_GetFloat32 times GetFloat32 against a State pre-populated
// with 1000 float32 entries. The lookup keys are built once before the timer
// is reset so that the timings and reported allocations cover the lookup,
// not string construction.
func BenchmarkState_GetFloat32(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with float32 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], float32(i))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetFloat32(keys[i%n])
	}
}
// BenchmarkState_GetComplex64 times GetComplex64 against a State
// pre-populated with 1000 complex64 entries. The lookup keys are built once
// before the timer is reset so that the timings and reported allocations
// cover the lookup, not string construction.
func BenchmarkState_GetComplex64(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with complex64 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		// Create a complex64 value with both real and imaginary parts.
		st.Set(keys[i], complex(float32(i), float32(i)))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetComplex64(keys[i%n])
	}
}
// BenchmarkState_GetComplex128 times GetComplex128 against a State
// pre-populated with 1000 complex128 entries. The lookup keys are built once
// before the timer is reset so that the timings and reported allocations
// cover the lookup, not string construction.
func BenchmarkState_GetComplex128(b *testing.B) {
	b.ReportAllocs()

	st := newState()
	const n = 1000

	// Pre-populate the state with complex128 values, remembering each key.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		// Create a complex128 value with both real and imaginary parts.
		st.Set(keys[i], complex(float64(i), float64(i)))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		st.GetComplex128(keys[i%n])
	}
}