diff --git a/ctx.go b/ctx.go index 4239a601..8af51056 100644 --- a/ctx.go +++ b/ctx.go @@ -1005,7 +1005,7 @@ func (c *DefaultCtx) Protocol() string { // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (c *DefaultCtx) Query(key string, defaultValue ...string) string { - return defaultString(c.app.getString(c.fasthttp.QueryArgs().Peek(key)), defaultValue) + return Query[string](c, key, defaultValue...) } // Queries returns a map of query parameters and their values. @@ -1037,67 +1037,98 @@ func (c *DefaultCtx) Queries() map[string]string { return m } -// QueryInt returns integer value of key string parameter in the url. -// Default to empty or invalid key is 0. +// Query Retrieves the value of a query parameter from the request's URI. +// The function is generic and can handle query parameter values of different types. +// It takes the following parameters: +// - c: The context object representing the current request. +// - key: The name of the query parameter. +// - defaultValue: (Optional) The default value to return in case the query parameter is not found or cannot be parsed. +// The function performs the following steps: +// 1. Type-asserts the context object to *DefaultCtx. +// 2. Retrieves the raw query parameter value from the request's URI. +// 3. Parses the raw value into the appropriate type based on the generic type parameter V. +// If parsing fails, the function checks if a default value is provided. If so, it returns the default value. +// 4. Returns the parsed value. // -// GET /?name=alex&wanna_cake=2&id= -// QueryInt("wanna_cake", 1) == 2 -// QueryInt("name", 1) == 1 -// QueryInt("id", 1) == 1 -// QueryInt("id") == 0 -func (c *DefaultCtx) QueryInt(key string, defaultValue ...int) int { - // Use Atoi to convert the param to an int or return zero and an error - value, err := strconv.Atoi(c.app.getString(c.fasthttp.QueryArgs().Peek(key))) - if err != nil { +// If the generic type cannot be matched to a supported type, the function returns the default value (if provided) or the zero value of type V. +// +// Example usage: +// +// GET /?search=john&age=8 +// name := Query[string](c, "search") // Returns "john" +// age := Query[int](c, "age") // Returns 8 +// unknown := Query[string](c, "unknown", "default") // Returns "default" since the query parameter "unknown" is not found +func Query[V QueryType](c Ctx, key string, defaultValue ...V) V { + ctx, ok := c.(*DefaultCtx) + if !ok { + panic(fmt.Errorf("failed to type-assert to *DefaultCtx")) + } + var v V + q := ctx.app.getString(ctx.fasthttp.QueryArgs().Peek(key)) + + switch any(v).(type) { + case int: + return queryParseInt[V](q, 32, func(i int64) V { return assertValueType[V, int](int(i)) }, defaultValue...) + case int8: + return queryParseInt[V](q, 8, func(i int64) V { return assertValueType[V, int8](int8(i)) }, defaultValue...) + case int16: + return queryParseInt[V](q, 16, func(i int64) V { return assertValueType[V, int16](int16(i)) }, defaultValue...) + case int32: + return queryParseInt[V](q, 32, func(i int64) V { return assertValueType[V, int32](int32(i)) }, defaultValue...) + case int64: + return queryParseInt[V](q, 64, func(i int64) V { return assertValueType[V, int64](i) }, defaultValue...) + case uint: + return queryParseUint[V](q, 32, func(i uint64) V { return assertValueType[V, uint](uint(i)) }, defaultValue...) + case uint8: + return queryParseUint[V](q, 8, func(i uint64) V { return assertValueType[V, uint8](uint8(i)) }, defaultValue...) + case uint16: + return queryParseUint[V](q, 16, func(i uint64) V { return assertValueType[V, uint16](uint16(i)) }, defaultValue...) + case uint32: + return queryParseUint[V](q, 32, func(i uint64) V { return assertValueType[V, uint32](uint32(i)) }, defaultValue...) + case uint64: + return queryParseUint[V](q, 64, func(i uint64) V { return assertValueType[V, uint64](i) }, defaultValue...) + case float32: + return queryParseFloat[V](q, 32, func(i float64) V { return assertValueType[V, float32](float32(i)) }, defaultValue...) + case float64: + return queryParseFloat[V](q, 64, func(i float64) V { return assertValueType[V, float64](i) }, defaultValue...) + case bool: + return queryParseBool[V](q, func(b bool) V { return assertValueType[V, bool](b) }, defaultValue...) + case string: + if q == "" && len(defaultValue) > 0 { + return defaultValue[0] + } + return assertValueType[V, string](q) + case []byte: + if q == "" && len(defaultValue) > 0 { + return defaultValue[0] + } + return assertValueType[V, []byte](ctx.app.getBytes(q)) + default: if len(defaultValue) > 0 { return defaultValue[0] } - return 0 + return v } - - return value } -// QueryBool returns bool value of key string parameter in the url. -// Default to empty or invalid key is true. -// -// Get /?name=alex&want_pizza=false&id= -// QueryBool("want_pizza") == false -// QueryBool("want_pizza", true) == false -// QueryBool("name") == false -// QueryBool("name", true) == true -// QueryBool("id") == false -// QueryBool("id", true) == true -func (c *DefaultCtx) QueryBool(key string, defaultValue ...bool) bool { - value, err := strconv.ParseBool(c.app.getString(c.fasthttp.QueryArgs().Peek(key))) - if err != nil { - if len(defaultValue) > 0 { - return defaultValue[0] - } - return false - } - return value +type QueryType interface { + QueryTypeInteger | QueryTypeFloat | bool | string | []byte } -// QueryFloat returns float64 value of key string parameter in the url. -// Default to empty or invalid key is 0. -// -// GET /?name=alex&amount=32.23&id= -// QueryFloat("amount") = 32.23 -// QueryFloat("amount", 3) = 32.23 -// QueryFloat("name", 1) = 1 -// QueryFloat("name") = 0 -// QueryFloat("id", 3) = 3 -func (c *DefaultCtx) QueryFloat(key string, defaultValue ...float64) float64 { - // use strconv.ParseFloat to convert the param to a float or return zero and an error. - value, err := strconv.ParseFloat(c.app.getString(c.fasthttp.QueryArgs().Peek(key)), 64) - if err != nil { - if len(defaultValue) > 0 { - return defaultValue[0] - } - return 0 - } - return value +type QueryTypeInteger interface { + QueryTypeIntegerSigned | QueryTypeIntegerUnsigned +} + +type QueryTypeIntegerSigned interface { + int | int8 | int16 | int32 | int64 +} + +type QueryTypeIntegerUnsigned interface { + uint | uint8 | uint16 | uint32 | uint64 +} + +type QueryTypeFloat interface { + float32 | float64 } // Range returns a struct containing the type and a slice of ranges. diff --git a/ctx_interface.go b/ctx_interface.go index a3a6fc88..ae8b1712 100644 --- a/ctx_interface.go +++ b/ctx_interface.go @@ -232,13 +232,6 @@ type Ctx interface { // Protocol returns the HTTP protocol of request: HTTP/1.1 and HTTP/2. Protocol() string - // Query returns the query string parameter in the url. - // Defaults to empty string "" if the query doesn't exist. - // If a default value is given, it will return that value if the query doesn't exist. - // Returned value is only valid within the handler. Do not store any references. - // Make copies or use the Immutable setting to use the value outside the Handler. - Query(key string, defaultValue ...string) string - // Queries returns a map of query parameters and their values. // // GET /?name=alex&wanna_cake=2&id= @@ -262,38 +255,12 @@ type Ctx interface { // Queries()["filters[status]"] == "pending" Queries() map[string]string - // QueryInt returns integer value of key string parameter in the url. - // Default to empty or invalid key is 0. - // - // GET /?name=alex&wanna_cake=2&id= - // QueryInt("wanna_cake", 1) == 2 - // QueryInt("name", 1) == 1 - // QueryInt("id", 1) == 1 - // QueryInt("id") == 0 - QueryInt(key string, defaultValue ...int) int - - // QueryBool returns bool value of key string parameter in the url. - // Default to empty or invalid key is true. - // - // Get /?name=alex&want_pizza=false&id= - // QueryBool("want_pizza") == false - // QueryBool("want_pizza", true) == false - // QueryBool("name") == false - // QueryBool("name", true) == true - // QueryBool("id") == false - // QueryBool("id", true) == true - QueryBool(key string, defaultValue ...bool) bool - - // QueryFloat returns float64 value of key string parameter in the url. - // Default to empty or invalid key is 0. - // - // GET /?name=alex&amount=32.23&id= - // QueryFloat("amount") = 32.23 - // QueryFloat("amount", 3) = 32.23 - // QueryFloat("name", 1) = 1 - // QueryFloat("name") = 0 - // QueryFloat("id", 3) = 3 - QueryFloat(key string, defaultValue ...float64) float64 + // Query returns the query string parameter in the url. + // Defaults to empty string "" if the query doesn't exist. + // If a default value is given, it will return that value if the query doesn't exist. + // Returned value is only valid within the handler. Do not store any references. + // Make copies or use the Immutable setting to use the value outside the Handler. + Query(key string, defaultValue ...string) string // Range returns a struct containing the type and a slice of ranges. Range(size int) (rangeData Range, err error) diff --git a/ctx_test.go b/ctx_test.go index ad9eafb8..afc29735 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -613,7 +613,7 @@ func Test_Ctx_UserContext_Multiple_Requests(t *testing.T) { return c.SendStatus(StatusInternalServerError) } - input := utils.CopyString(c.Query("input", "NO_VALUE")) + input := utils.CopyString(Query(c, "input", "NO_VALUE")) ctx = context.WithValue(ctx, testKey, fmt.Sprintf("%s_%s", testValue, input)) c.SetUserContext(ctx) @@ -2231,37 +2231,272 @@ func Test_Ctx_Query(t *testing.T) { require.Equal(t, "john", c.Query("search")) require.Equal(t, "20", c.Query("age")) require.Equal(t, "default", c.Query("unknown", "default")) + + // test with generic + require.Equal(t, "john", Query[string](c, "search")) + require.Equal(t, "20", Query[string](c, "age")) + require.Equal(t, "default", Query[string](c, "unknown", "default")) } -func Test_Ctx_QueryInt(t *testing.T) { +// go test -v -run=^$ -bench=Benchmark_Ctx_Query -benchmem -count=4 +func Benchmark_Ctx_Query(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res string + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[string](c, "search") + } + require.Equal(b, "john", res) +} + +// go test -run Test_Ctx_QuerySignedInt +func Test_Ctx_QuerySignedInt(t *testing.T) { t.Parallel() app := New() c := app.NewCtx(&fasthttp.RequestCtx{}) - c.Request().URI().SetQueryString("search=john&age=20&id=") - require.Equal(t, 0, c.QueryInt("foo")) - require.Equal(t, 20, c.QueryInt("age", 12)) - require.Equal(t, 0, c.QueryInt("search")) - require.Equal(t, 1, c.QueryInt("search", 1)) - require.Equal(t, 0, c.QueryInt("id")) - require.Equal(t, 2, c.QueryInt("id", 2)) + c.Request().URI().SetQueryString("search=john&age=8") + // int + require.Equal(t, 0, Query[int](c, "foo")) + require.Equal(t, 8, Query[int](c, "age", 12)) + require.Equal(t, 0, Query[int](c, "search")) + require.Equal(t, 1, Query[int](c, "search", 1)) + require.Equal(t, 0, Query[int](c, "id")) + require.Equal(t, 2, Query[int](c, "id", 2)) + + // int8 + require.Equal(t, int8(0), Query[int8](c, "foo")) + require.Equal(t, int8(8), Query[int8](c, "age", 12)) + require.Equal(t, int8(0), Query[int8](c, "search")) + require.Equal(t, int8(1), Query[int8](c, "search", 1)) + require.Equal(t, int8(0), Query[int8](c, "id")) + require.Equal(t, int8(2), Query[int8](c, "id", 2)) + + // int16 + require.Equal(t, int16(0), Query[int16](c, "foo")) + require.Equal(t, int16(8), Query[int16](c, "age", 12)) + require.Equal(t, int16(0), Query[int16](c, "search")) + require.Equal(t, int16(1), Query[int16](c, "search", 1)) + require.Equal(t, int16(0), Query[int16](c, "id")) + require.Equal(t, int16(2), Query[int16](c, "id", 2)) + + // int32 + require.Equal(t, int32(0), Query[int32](c, "foo")) + require.Equal(t, int32(8), Query[int32](c, "age", 12)) + require.Equal(t, int32(0), Query[int32](c, "search")) + require.Equal(t, int32(1), Query[int32](c, "search", 1)) + require.Equal(t, int32(0), Query[int32](c, "id")) + require.Equal(t, int32(2), Query[int32](c, "id", 2)) + + // int64 + require.Equal(t, int64(0), Query[int64](c, "foo")) + require.Equal(t, int64(8), Query[int64](c, "age", 12)) + require.Equal(t, int64(0), Query[int64](c, "search")) + require.Equal(t, int64(1), Query[int64](c, "search", 1)) + require.Equal(t, int64(0), Query[int64](c, "id")) + require.Equal(t, int64(2), Query[int64](c, "id", 2)) } -func Test_Ctx_QueryBool(t *testing.T) { +// go test -v -run=^$ -bench=Benchmark_Ctx_QuerySignedInt -benchmem -count=4 +func Benchmark_Ctx_QuerySignedInt(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res int + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[int](c, "age") + } + require.Equal(b, 8, res) +} + +// go test -run Test_Ctx_QueryBoundarySignedInt +func Test_Ctx_QueryBoundarySignedInt(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + var q string + + // int + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "2147483647", "-2147483648", "-2147483649", "2147483648") + c.Request().URI().SetQueryString(q) + require.Equal(t, 2147483647, Query[int](c, "minus")) + require.Equal(t, -2147483648, Query[int](c, "plus")) + require.Equal(t, 0, Query[int](c, "minus_over")) + require.Equal(t, 0, Query[int](c, "plus_over")) + + // int8 + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "127", "-128", "-129", "128") + c.Request().URI().SetQueryString(q) + require.Equal(t, int8(127), Query[int8](c, "minus")) + require.Equal(t, int8(-128), Query[int8](c, "plus")) + require.Equal(t, int8(0), Query[int8](c, "minus_over")) + require.Equal(t, int8(0), Query[int8](c, "plus_over")) + + // int16 + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "32767", "-32768", "-32769", "32768") + c.Request().URI().SetQueryString(q) + require.Equal(t, int16(32767), Query[int16](c, "minus")) + require.Equal(t, int16(-32768), Query[int16](c, "plus")) + require.Equal(t, int16(0), Query[int16](c, "minus_over")) + require.Equal(t, int16(0), Query[int16](c, "plus_over")) + + // int32 + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "2147483647", "-2147483648", "-2147483649", "2147483648") + c.Request().URI().SetQueryString(q) + require.Equal(t, int32(2147483647), Query[int32](c, "minus")) + require.Equal(t, int32(-2147483648), Query[int32](c, "plus")) + require.Equal(t, int32(0), Query[int32](c, "minus_over")) + require.Equal(t, int32(0), Query[int32](c, "plus_over")) + + // int64 + q = fmt.Sprintf("minus=%s&plus=%s", "-9223372036854775808", "9223372036854775807") + c.Request().URI().SetQueryString(q) + require.Equal(t, int64(-9223372036854775808), Query[int64](c, "minus")) + require.Equal(t, int64(9223372036854775807), Query[int64](c, "plus")) +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryBoundarySignedInt -benchmem -count=4 +func Benchmark_Ctx_QueryBoundarySignedInt(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res int + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[int](c, "age") + } + require.Equal(b, 8, res) +} + +// go test -run Test_Ctx_QueryUnsignedInt +func Test_Ctx_QueryUnsignedInt(t *testing.T) { t.Parallel() app := New() c := app.NewCtx(&fasthttp.RequestCtx{}) - c.Request().URI().SetQueryString("name=alex&want_pizza=false&id=") + c.Request().URI().SetQueryString("search=john&age=8") + // uint + require.Equal(t, uint(0), Query[uint](c, "foo")) + require.Equal(t, uint(8), Query[uint](c, "age", 12)) + require.Equal(t, uint(0), Query[uint](c, "search")) + require.Equal(t, uint(1), Query[uint](c, "search", 1)) + require.Equal(t, uint(0), Query[uint](c, "id")) + require.Equal(t, uint(2), Query[uint](c, "id", 2)) - require.Equal(t, false, c.QueryBool("want_pizza")) - require.Equal(t, false, c.QueryBool("want_pizza", true)) - require.Equal(t, false, c.QueryBool("name")) - require.Equal(t, true, c.QueryBool("name", true)) - require.Equal(t, false, c.QueryBool("id")) - require.Equal(t, true, c.QueryBool("id", true)) + // uint8 + require.Equal(t, uint8(0), Query[uint8](c, "foo")) + require.Equal(t, uint8(8), Query[uint8](c, "age", 12)) + require.Equal(t, uint8(0), Query[uint8](c, "search")) + require.Equal(t, uint8(1), Query[uint8](c, "search", 1)) + require.Equal(t, uint8(0), Query[uint8](c, "id")) + require.Equal(t, uint8(2), Query[uint8](c, "id", 2)) + + // uint16 + require.Equal(t, uint16(0), Query[uint16](c, "foo")) + require.Equal(t, uint16(8), Query[uint16](c, "age", 12)) + require.Equal(t, uint16(0), Query[uint16](c, "search")) + require.Equal(t, uint16(1), Query[uint16](c, "search", 1)) + require.Equal(t, uint16(0), Query[uint16](c, "id")) + require.Equal(t, uint16(2), Query[uint16](c, "id", 2)) + + // uint32 + require.Equal(t, uint32(0), Query[uint32](c, "foo")) + require.Equal(t, uint32(8), Query[uint32](c, "age", 12)) + require.Equal(t, uint32(0), Query[uint32](c, "search")) + require.Equal(t, uint32(1), Query[uint32](c, "search", 1)) + require.Equal(t, uint32(0), Query[uint32](c, "id")) + require.Equal(t, uint32(2), Query[uint32](c, "id", 2)) + + // uint64 + require.Equal(t, uint64(0), Query[uint64](c, "foo")) + require.Equal(t, uint64(8), Query[uint64](c, "age", 12)) + require.Equal(t, uint64(0), Query[uint64](c, "search")) + require.Equal(t, uint64(1), Query[uint64](c, "search", 1)) + require.Equal(t, uint64(0), Query[uint64](c, "id")) + require.Equal(t, uint64(2), Query[uint64](c, "id", 2)) } +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryUnsignedInt -benchmem -count=4 +func Benchmark_Ctx_QueryUnsignedInt(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res uint + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[uint](c, "age") + } + require.Equal(b, uint(8), res) +} + +// go test -run Test_Ctx_QueryBoundaryUnsignedInt +func Test_Ctx_QueryBoundaryUnsignedInt(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + var q string + + // uint + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "0", "4294967295", "4294967296", "4294967297") + c.Request().URI().SetQueryString(q) + require.Equal(t, uint(0), Query[uint](c, "minus")) + require.Equal(t, uint(4294967295), Query[uint](c, "plus")) + require.Equal(t, uint(0), Query[uint](c, "minus_over")) + require.Equal(t, uint(0), Query[uint](c, "plus_over")) + + // uint8 + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "0", "255", "256", "257") + c.Request().URI().SetQueryString(q) + require.Equal(t, uint8(0), Query[uint8](c, "minus")) + require.Equal(t, uint8(255), Query[uint8](c, "plus")) + require.Equal(t, uint8(0), Query[uint8](c, "minus_over")) + require.Equal(t, uint8(0), Query[uint8](c, "plus_over")) + + // uint16 + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "0", "65535", "65536", "65537") + c.Request().URI().SetQueryString(q) + require.Equal(t, uint16(0), Query[uint16](c, "minus")) + require.Equal(t, uint16(65535), Query[uint16](c, "plus")) + require.Equal(t, uint16(0), Query[uint16](c, "minus_over")) + require.Equal(t, uint16(0), Query[uint16](c, "plus_over")) + + // uint32 + q = fmt.Sprintf("minus=%s&plus=%s&minus_over=%s&plus_over=%s", "0", "4294967295", "4294967296", "4294967297") + c.Request().URI().SetQueryString(q) + require.Equal(t, uint32(0), Query[uint32](c, "minus")) + require.Equal(t, uint32(4294967295), Query[uint32](c, "plus")) + require.Equal(t, uint32(0), Query[uint32](c, "minus_over")) + require.Equal(t, uint32(0), Query[uint32](c, "plus_over")) + + // uint64 + q = fmt.Sprintf("minus=%s&plus=%s", "0", "18446744073709551615") + c.Request().URI().SetQueryString(q) + require.Equal(t, uint64(0), Query[uint64](c, "minus")) + require.Equal(t, uint64(18446744073709551615), Query[uint64](c, "plus")) +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryBoundaryUnsignedInt -benchmem -count=4 +func Benchmark_Ctx_QueryBoundaryUnsignedInt(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res uint + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[uint](c, "age") + } + require.Equal(b, uint(8), res) +} + +// go test -run Test_Ctx_QueryFloat func Test_Ctx_QueryFloat(t *testing.T) { t.Parallel() app := New() @@ -2269,12 +2504,179 @@ func Test_Ctx_QueryFloat(t *testing.T) { c.Request().URI().SetQueryString("name=alex&amount=32.23&id=") - require.Equal(t, 32.23, c.QueryFloat("amount")) - require.Equal(t, 32.23, c.QueryFloat("amount", 3.123)) - require.Equal(t, 87.123, c.QueryFloat("name", 87.123)) - require.Equal(t, float64(0), c.QueryFloat("name")) - require.Equal(t, 12.87, c.QueryFloat("id", 12.87)) - require.Equal(t, float64(0), c.QueryFloat("id")) + // float32 + require.Equal(t, float32(32.23), Query[float32](c, "amount")) + require.Equal(t, float32(32.23), Query[float32](c, "amount", 3.123)) + require.Equal(t, float32(87.123), Query[float32](c, "name", 87.123)) + require.Equal(t, float32(0), Query[float32](c, "name")) + require.Equal(t, float32(12.87), Query[float32](c, "id", 12.87)) + require.Equal(t, float32(0), Query[float32](c, "id")) + + // float64 + require.Equal(t, 32.23, Query[float64](c, "amount")) + require.Equal(t, 32.23, Query[float64](c, "amount", 3.123)) + require.Equal(t, 87.123, Query[float64](c, "name", 87.123)) + require.Equal(t, float64(0), Query[float64](c, "name")) + require.Equal(t, 12.87, Query[float64](c, "id", 12.87)) + require.Equal(t, float64(0), Query[float64](c, "id")) +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryFloat -benchmem -count=4 +func Benchmark_Ctx_QueryFloat(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res float32 + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[float32](c, "age") + } + require.Equal(b, float32(8), res) +} + +// go test -run Test_Ctx_QueryBool +func Test_Ctx_QueryBool(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Request().URI().SetQueryString("name=alex&want_pizza=false&id=") + + require.Equal(t, false, Query[bool](c, "want_pizza")) + require.Equal(t, false, Query[bool](c, "want_pizza", true)) + require.Equal(t, false, Query[bool](c, "name")) + require.Equal(t, true, Query[bool](c, "name", true)) + require.Equal(t, false, Query[bool](c, "id")) + require.Equal(t, true, Query[bool](c, "id", true)) +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryBool -benchmem -count=4 +func Benchmark_Ctx_QueryBool(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res bool + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[bool](c, "age") + } + require.Equal(b, false, res) +} + +// go test -run Test_Ctx_QueryString +func Test_Ctx_QueryString(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Request().URI().SetQueryString("name=alex&amount=32.23&id=") + + require.Equal(t, "alex", Query[string](c, "name")) + require.Equal(t, "alex", Query[string](c, "name", "john")) + require.Equal(t, "32.23", Query[string](c, "amount")) + require.Equal(t, "32.23", Query[string](c, "amount", "3.123")) + require.Equal(t, "", Query[string](c, "id")) + require.Equal(t, "12.87", Query[string](c, "id", "12.87")) +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryString -benchmem -count=4 +func Benchmark_Ctx_QueryString(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res string + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[string](c, "age") + } + require.Equal(b, "8", res) +} + +// go test -run Test_Ctx_QueryBytes +func Test_Ctx_QueryBytes(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Request().URI().SetQueryString("name=alex&amount=32.23&id=") + + require.Equal(t, []byte("alex"), Query[[]byte](c, "name")) + require.Equal(t, []byte("alex"), Query[[]byte](c, "name", []byte("john"))) + require.Equal(t, []byte("32.23"), Query[[]byte](c, "amount")) + require.Equal(t, []byte("32.23"), Query[[]byte](c, "amount", []byte("3.123"))) + require.Equal(t, []byte(nil), Query[[]byte](c, "id")) + require.Equal(t, []byte("12.87"), Query[[]byte](c, "id", []byte("12.87"))) +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryBytes -benchmem -count=4 +func Benchmark_Ctx_QueryBytes(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res []byte + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query[[]byte](c, "age") + } + require.Equal(b, []byte("8"), res) +} + +// go test -run Test_Ctx_QueryWithoutGenericDataType +func Test_Ctx_QueryWithoutGenericDataType(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Request().URI().SetQueryString("name=alex&amount=32.23&isAgent=true&id=32") + + require.Equal(t, "alex", Query(c, "name", "john")) + require.Equal(t, "john", Query(c, "unknown", "john")) + require.Equal(t, 32, Query(c, "id", 3)) + require.Equal(t, 3, Query(c, "unknown", 3)) + require.Equal(t, int8(32), Query(c, "id", int8(3))) + require.Equal(t, int8(3), Query(c, "unknown", int8(3))) + require.Equal(t, int16(32), Query(c, "id", int16(3))) + require.Equal(t, int16(3), Query(c, "unknown", int16(3))) + require.Equal(t, int32(32), Query(c, "id", int32(3))) + require.Equal(t, int32(3), Query(c, "unknown", int32(3))) + require.Equal(t, int64(32), Query(c, "id", int64(3))) + require.Equal(t, int64(3), Query(c, "unknown", int64(3))) + require.Equal(t, uint(32), Query(c, "id", uint(3))) + require.Equal(t, uint(3), Query(c, "unknown", uint(3))) + require.Equal(t, uint8(32), Query(c, "id", uint8(3))) + require.Equal(t, uint8(3), Query(c, "unknown", uint8(3))) + require.Equal(t, uint16(32), Query(c, "id", uint16(3))) + require.Equal(t, uint16(3), Query(c, "unknown", uint16(3))) + require.Equal(t, uint32(32), Query(c, "id", uint32(3))) + require.Equal(t, uint32(3), Query(c, "unknown", uint32(3))) + require.Equal(t, uint64(32), Query(c, "id", uint64(3))) + require.Equal(t, uint64(3), Query(c, "unknown", uint64(3))) + require.Equal(t, 32.23, Query(c, "amount", 3.123)) + require.Equal(t, 3.123, Query(c, "unknown", 3.123)) + require.Equal(t, float32(32.23), Query(c, "amount", float32(3.123))) + require.Equal(t, float32(3.123), Query(c, "unknown", float32(3.123))) + require.Equal(t, true, Query(c, "isAgent", false)) + require.Equal(t, false, Query(c, "unknown", false)) + require.Equal(t, []byte("alex"), Query(c, "name", []byte("john"))) + require.Equal(t, []byte("john"), Query(c, "unknown", []byte("john"))) +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_QueryWithoutGenericDataType -benchmem -count=4 +func Benchmark_Ctx_QueryWithoutGenericDataType(b *testing.B) { + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + c.Request().URI().SetQueryString("search=john&age=8") + var res int + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = Query(c, "age", 3) + } + require.Equal(b, 8, res) } // go test -run Test_Ctx_Range diff --git a/docs/api/ctx.md b/docs/api/ctx.md index 9323709e..ca0199e5 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -1349,94 +1349,42 @@ app.Get("/", func(c fiber.Ctx) error { > _Returned value is only valid within the handler. Do not store any references. > Make copies or use the_ [_**`Immutable`**_](ctx.md) _setting instead._ [_Read more..._](../#zero-allocation) -## QueryBool +In certain scenarios, it can be useful to have an alternative approach to handle different types of query parameters, not +just strings. This can be achieved using a generic Query function known as `Query[V QueryType](c Ctx, key string, defaultValue ...V) V`. +This function is capable of parsing a query string and returning a value of a type that is assumed and specified by `V QueryType`. -This property is an object containing a property for each query boolean parameter in the route, you could pass an optional default value that will be returned if the query key does not exist. - -:::caution -Please note if that parameter is not in the request, false will be returned. -If the parameter is not a boolean, it is still tried to be converted and usually returned as false. -::: +Here is the signature for the generic Query function: ```go title="Signature" -func (c *Ctx) QueryBool(key string, defaultValue ...bool) bool +func Query[V QueryType](c Ctx, key string, defaultValue ...V) V ``` +Consider this example: + ```go title="Example" -// GET http://example.com/?name=alex&want_pizza=false&id= +// GET http://example.com/?page=1&brand=nike&new=true app.Get("/", func(c fiber.Ctx) error { - c.QueryBool("want_pizza") // false - c.QueryBool("want_pizza", true) // false - c.QueryBool("name") // false - c.QueryBool("name", true) // true - c.QueryBool("id") // false - c.QueryBool("id", true) // true + fiber.Query[int](c, "page") // 1 + fiber.Query[string](c, "brand") // "nike" + fiber.Query[bool](c, "new") // true // ... }) ``` -## QueryFloat +In this case, `Query[V QueryType](c Ctx, key string, defaultValue ...V) V` can retrieve 'page' as an integer, 'brand' +as a string, and 'new' as a boolean. The function uses the appropriate parsing function for each specified type to ensure +the correct type is returned. This simplifies the retrieval process of different types of query parameters, making your +controller actions cleaner. -This property is an object containing a property for each query float64 parameter in the route, you could pass an optional default value that will be returned if the query key does not exist. - -:::caution -Please note if that parameter is not in the request, zero will be returned. -If the parameter is not a number, it is still tried to be converted and usually returned as 1. -::: - -:::info -Defaults to the float64 zero \(`0`\), if the param **doesn't** exist. -::: - -```go title="Signature" -func (c *Ctx) QueryFloat(key string, defaultValue ...float64) float64 -``` - -```go title="Example" -// GET http://example.com/?name=alex&amount=32.23&id= - -app.Get("/", func(c fiber.Ctx) error { - c.QueryFloat("amount") // 32.23 - c.QueryFloat("amount", 3) // 32.23 - c.QueryFloat("name", 1) // 1 - c.QueryFloat("name") // 0 - c.QueryFloat("id", 3) // 3 - - // ... -}) -``` - -## QueryInt - -This property is an object containing a property for each query integer parameter in the route, you could pass an optional default value that will be returned if the query key does not exist. - -:::caution -Please note if that parameter is not in the request, zero will be returned. -If the parameter is not a number, it is still tried to be converted and usually returned as 1. -::: - -:::info -Defaults to the integer zero \(`0`\), if the param **doesn't** exist. -::: - -```go title="Signature" -func (c *Ctx) QueryInt(key string, defaultValue ...int) int -``` - -```go title="Example" -// GET http://example.com/?name=alex&wanna_cake=2&id= - -app.Get("/", func(c fiber.Ctx) error { - c.QueryInt("wanna_cake", 1) // 2 - c.QueryInt("name", 1) // 1 - c.QueryInt("id", 1) // 1 - c.QueryInt("id") // 0 - - // ... -}) -``` +The generic Query function supports returning the following data types based on V QueryType: +- Integer: int, int8, int16, int32, int64 +- Unsigned integer: uint, uint8, uint16, uint32, uint64 +- Floating-point numbers: float32, float64 +- Boolean: bool +- String: string +- Byte array: []byte ## QueryParser diff --git a/docs/api/middleware/cache.md b/docs/api/middleware/cache.md index 4c731479..26022cf0 100644 --- a/docs/api/middleware/cache.md +++ b/docs/api/middleware/cache.md @@ -36,7 +36,7 @@ app.Use(cache.New()) // Or extend your config for customization app.Use(cache.New(cache.Config{ Next: func(c fiber.Ctx) bool { - return c.Query("noCache") == "true" + return fiber.Query[bool](c, "noCache") }, Expiration: 30 * time.Minute, CacheControl: true, diff --git a/docs/guide/validation.md b/docs/guide/validation.md index 4ed82b23..58131bc0 100644 --- a/docs/guide/validation.md +++ b/docs/guide/validation.md @@ -97,8 +97,8 @@ func main() { app.Get("/", func(c fiber.Ctx) error { user := &User{ - Name: c.Query("name"), - Age: c.QueryInt("age"), + Name: fiber.Query[string](c, "name"), + Age: fiber.Query[int](c, "age"), } // Validation diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 1da4496f..e362b30d 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -116,7 +116,7 @@ func Test_Cache_WithNoCacheRequestDirective(t *testing.T) { app.Use(New()) app.Get("/", func(c fiber.Ctx) error { - return c.SendString(c.Query("id", "1")) + return c.SendString(fiber.Query(c, "id", "1")) }) // Request id = 1 @@ -184,7 +184,7 @@ func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) { ) app.Get("/", func(c fiber.Ctx) error { - return c.SendString(c.Query("id", "1")) + return c.SendString(fiber.Query(c, "id", "1")) }) // Request id = 1 @@ -247,7 +247,7 @@ func Test_Cache_WithNoStoreRequestDirective(t *testing.T) { app.Use(New()) app.Get("/", func(c fiber.Ctx) error { - return c.SendString(c.Query("id", "1")) + return c.SendString(fiber.Query(c, "id", "1")) }) // Request id = 2 @@ -335,11 +335,11 @@ func Test_Cache_Get(t *testing.T) { app.Use(New()) app.Post("/", func(c fiber.Ctx) error { - return c.SendString(c.Query("cache")) + return c.SendString(fiber.Query[string](c, "cache")) }) app.Get("/get", func(c fiber.Ctx) error { - return c.SendString(c.Query("cache")) + return c.SendString(fiber.Query[string](c, "cache")) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil)) @@ -377,11 +377,11 @@ func Test_Cache_Post(t *testing.T) { })) app.Post("/", func(c fiber.Ctx) error { - return c.SendString(c.Query("cache")) + return c.SendString(fiber.Query[string](c, "cache")) }) app.Get("/get", func(c fiber.Ctx) error { - return c.SendString(c.Query("cache")) + return c.SendString(fiber.Query[string](c, "cache")) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil)) @@ -589,7 +589,7 @@ func Test_CacheHeader(t *testing.T) { }) app.Post("/", func(c fiber.Ctx) error { - return c.SendString(c.Query("cache")) + return c.SendString(fiber.Query[string](c, "cache")) }) app.Get("/error", func(c fiber.Ctx) error { @@ -650,7 +650,7 @@ func Test_Cache_WithHeadThenGet(t *testing.T) { app.Use(New()) handler := func(c fiber.Ctx) error { - return c.SendString(c.Query("cache")) + return c.SendString(fiber.Query[string](c, "cache")) } app.Route("/").Get(handler).Head(handler) diff --git a/middleware/csrf/extractors.go b/middleware/csrf/extractors.go index 5021301b..1396d2ec 100644 --- a/middleware/csrf/extractors.go +++ b/middleware/csrf/extractors.go @@ -61,7 +61,7 @@ func CsrfFromHeader(param string) func(c fiber.Ctx) (string, error) { // csrfFromQuery returns a function that extracts token from the query string. func CsrfFromQuery(param string) func(c fiber.Ctx) (string, error) { return func(c fiber.Ctx) (string, error) { - token := c.Query(param) + token := fiber.Query[string](c, param) if token == "" { return "", ErrMissingQuery } diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go index 86b9f894..914bca03 100644 --- a/middleware/keyauth/keyauth.go +++ b/middleware/keyauth/keyauth.go @@ -98,7 +98,7 @@ func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) // keyFromQuery returns a function that extracts api key from the query string. func keyFromQuery(param string) func(c fiber.Ctx) (string, error) { return func(c fiber.Ctx) (string, error) { - key := c.Query(param) + key := fiber.Query[string](c, param) if key == "" { return "", ErrMissingOrMalformedAPIKey } diff --git a/middleware/logger/tags.go b/middleware/logger/tags.go index b4ce5cfd..0d8f3ad3 100644 --- a/middleware/logger/tags.go +++ b/middleware/logger/tags.go @@ -157,7 +157,7 @@ func createTagMap(cfg *Config) map[string]LogFunc { return output.WriteString(c.GetRespHeader(extraParam)) }, TagQuery: func(output Buffer, c fiber.Ctx, data *Data, extraParam string) (int, error) { - return output.WriteString(c.Query(extraParam)) + return output.WriteString(fiber.Query[string](c, extraParam)) }, TagForm: func(output Buffer, c fiber.Ctx, data *Data, extraParam string) (int, error) { return output.WriteString(c.FormValue(extraParam)) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 1ec471bd..184f8bfa 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -666,7 +666,7 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { app1 := fiber.New() app1.Get("/test", func(c fiber.Ctx) error { - return c.SendString("test_local_client:" + c.Query("query_test")) + return c.SendString("test_local_client:" + fiber.Query[string](c, "query_test")) }) proxyAddr := ln.Addr().String() diff --git a/middleware/session/store.go b/middleware/session/store.go index d178753a..ef4ecb44 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -110,7 +110,7 @@ func (s *Store) getSessionID(c fiber.Ctx) string { } if s.source == SourceURLQuery { - id = c.Query(s.sessionName) + id = fiber.Query[string](c, s.sessionName) if len(id) > 0 { return utils.CopyString(id) } diff --git a/utils.go b/utils.go new file mode 100644 index 00000000..996afb9a --- /dev/null +++ b/utils.go @@ -0,0 +1,46 @@ +package fiber + +import ( + "fmt" + "strconv" +) + +// assertValueType asserts the type of the result to the type of the value +func assertValueType[V QueryType, T any](result T) V { + v, ok := any(result).(V) + if !ok { + panic(fmt.Errorf("failed to type-assert to %T", v)) + } + return v +} + +func queryParseDefault[V QueryType](err error, parser func() V, defaultValue ...V) V { + var v V + if err != nil { + if len(defaultValue) > 0 { + return defaultValue[0] + } + return v + } + return parser() +} + +func queryParseInt[V QueryType](q string, bitSize int, parser func(int64) V, defaultValue ...V) V { + result, err := strconv.ParseInt(q, 10, bitSize) + return queryParseDefault[V](err, func() V { return parser(result) }, defaultValue...) +} + +func queryParseUint[V QueryType](q string, bitSize int, parser func(uint64) V, defaultValue ...V) V { + result, err := strconv.ParseUint(q, 10, bitSize) + return queryParseDefault[V](err, func() V { return parser(result) }, defaultValue...) +} + +func queryParseFloat[V QueryType](q string, bitSize int, parser func(float64) V, defaultValue ...V) V { + result, err := strconv.ParseFloat(q, bitSize) + return queryParseDefault[V](err, func() V { return parser(result) }, defaultValue...) +} + +func queryParseBool[V QueryType](q string, parser func(bool) V, defaultValue ...V) V { + result, err := strconv.ParseBool(q) + return queryParseDefault[V](err, func() V { return parser(result) }, defaultValue...) +}