mirror of https://github.com/gofiber/fiber.git
bind: add support for multipart file binding (#3309)
* deps: update schema to v1.3.0 * bind: add support for multipart file binding * bind: fix linter * improve coverage * fix linter * add test cases --------- Co-authored-by: René <rene@gofiber.io>pull/3334/head
parent
d6d48d8cb7
commit
bc4c920ea6
|
@ -1,6 +1,8 @@
|
|||
package binder
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
|
||||
"github.com/gofiber/utils/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
@ -59,7 +61,15 @@ func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
|
|||
}
|
||||
}
|
||||
|
||||
return parse(b.Name(), out, data)
|
||||
files := make(map[string][]*multipart.FileHeader)
|
||||
for key, values := range multipartForm.File {
|
||||
err = formatBindData(out, files, key, values, b.EnableSplitting, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return parse(b.Name(), out, data, files)
|
||||
}
|
||||
|
||||
// Reset resets the FormBinding binder.
|
||||
|
|
|
@ -2,6 +2,7 @@ package binder
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"testing"
|
||||
|
||||
|
@ -98,10 +99,12 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
|
|||
}
|
||||
|
||||
type User struct {
|
||||
Name string `form:"name"`
|
||||
Names []string `form:"names"`
|
||||
Posts []Post `form:"posts"`
|
||||
Age int `form:"age"`
|
||||
Avatar *multipart.FileHeader `form:"avatar"`
|
||||
Name string `form:"name"`
|
||||
Names []string `form:"names"`
|
||||
Posts []Post `form:"posts"`
|
||||
Avatars []*multipart.FileHeader `form:"avatars"`
|
||||
Age int `form:"age"`
|
||||
}
|
||||
var user User
|
||||
|
||||
|
@ -118,6 +121,24 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
|
|||
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
|
||||
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))
|
||||
|
||||
writer, err := mw.CreateFormFile("avatar", "avatar.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write([]byte("avatar"))
|
||||
require.NoError(t, err)
|
||||
|
||||
writer, err = mw.CreateFormFile("avatars", "avatar1.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write([]byte("avatar1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
writer, err = mw.CreateFormFile("avatars", "avatar2.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write([]byte("avatar2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, mw.Close())
|
||||
|
||||
req.Header.SetContentType(mw.FormDataContentType())
|
||||
|
@ -127,7 +148,7 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
|
|||
fasthttp.ReleaseRequest(req)
|
||||
})
|
||||
|
||||
err := b.Bind(req, &user)
|
||||
err = b.Bind(req, &user)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "john", user.Name)
|
||||
|
@ -139,6 +160,38 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
|
|||
require.Equal(t, "post1", user.Posts[0].Title)
|
||||
require.Equal(t, "post2", user.Posts[1].Title)
|
||||
require.Equal(t, "post3", user.Posts[2].Title)
|
||||
|
||||
require.NotNil(t, user.Avatar)
|
||||
require.Equal(t, "avatar.txt", user.Avatar.Filename)
|
||||
require.Equal(t, "application/octet-stream", user.Avatar.Header.Get("Content-Type"))
|
||||
|
||||
file, err := user.Avatar.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "avatar", string(content))
|
||||
|
||||
require.Len(t, user.Avatars, 2)
|
||||
require.Equal(t, "avatar1.txt", user.Avatars[0].Filename)
|
||||
require.Equal(t, "application/octet-stream", user.Avatars[0].Header.Get("Content-Type"))
|
||||
|
||||
file, err = user.Avatars[0].Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err = io.ReadAll(file)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "avatar1", string(content))
|
||||
|
||||
require.Equal(t, "avatar2.txt", user.Avatars[1].Filename)
|
||||
require.Equal(t, "application/octet-stream", user.Avatars[1].Header.Get("Content-Type"))
|
||||
|
||||
file, err = user.Avatars[1].Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err = io.ReadAll(file)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "avatar2", string(content))
|
||||
}
|
||||
|
||||
func Benchmark_FormBinder_BindMultipart(b *testing.B) {
|
||||
|
|
|
@ -3,6 +3,7 @@ package binder
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -69,7 +70,7 @@ func init() {
|
|||
}
|
||||
|
||||
// parse data into the map or struct
|
||||
func parse(aliasTag string, out any, data map[string][]string) error {
|
||||
func parse(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
|
||||
ptrVal := reflect.ValueOf(out)
|
||||
|
||||
// Get pointer value
|
||||
|
@ -83,11 +84,11 @@ func parse(aliasTag string, out any, data map[string][]string) error {
|
|||
}
|
||||
|
||||
// Parse into the struct
|
||||
return parseToStruct(aliasTag, out, data)
|
||||
return parseToStruct(aliasTag, out, data, files...)
|
||||
}
|
||||
|
||||
// Parse data into the struct with gorilla/schema
|
||||
func parseToStruct(aliasTag string, out any, data map[string][]string) error {
|
||||
func parseToStruct(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
|
||||
// Get decoder from pool
|
||||
schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder) //nolint:errcheck,forcetypeassert // not needed
|
||||
defer decoderPoolMap[aliasTag].Put(schemaDecoder)
|
||||
|
@ -95,7 +96,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error {
|
|||
// Set alias tag
|
||||
schemaDecoder.SetAliasTag(aliasTag)
|
||||
|
||||
if err := schemaDecoder.Decode(out, data); err != nil {
|
||||
if err := schemaDecoder.Decode(out, data, files...); err != nil {
|
||||
return fmt.Errorf("bind: %w", err)
|
||||
}
|
||||
|
||||
|
@ -250,7 +251,7 @@ func FilterFlags(content string) string {
|
|||
return content
|
||||
}
|
||||
|
||||
func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
|
||||
func formatBindData[T, K any](out any, data map[string][]T, key string, value K, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
|
||||
var err error
|
||||
if supportBracketNotation && strings.Contains(key, "[") {
|
||||
key, err = parseParamSquareBrackets(key)
|
||||
|
@ -261,10 +262,28 @@ func formatBindData[T any](out any, data map[string][]string, key string, value
|
|||
|
||||
switch v := any(value).(type) {
|
||||
case string:
|
||||
assignBindData(out, data, key, v, enableSplitting)
|
||||
dataMap, ok := any(data).(map[string][]string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
}
|
||||
|
||||
assignBindData(out, dataMap, key, v, enableSplitting)
|
||||
case []string:
|
||||
dataMap, ok := any(data).(map[string][]string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
}
|
||||
|
||||
for _, val := range v {
|
||||
assignBindData(out, data, key, val, enableSplitting)
|
||||
assignBindData(out, dataMap, key, val, enableSplitting)
|
||||
}
|
||||
case []*multipart.FileHeader:
|
||||
for _, val := range v {
|
||||
valT, ok := any(val).(T)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
}
|
||||
data[key] = append(data[key], valT)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
|
|
|
@ -2,6 +2,7 @@ package binder
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -9,6 +10,8 @@ import (
|
|||
)
|
||||
|
||||
func Test_EqualFieldType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var out int
|
||||
require.False(t, equalFieldType(&out, reflect.Int, "key"))
|
||||
|
||||
|
@ -47,6 +50,8 @@ func Test_EqualFieldType(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_ParseParamSquareBrackets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
err error
|
||||
input string
|
||||
|
@ -101,6 +106,8 @@ func Test_ParseParamSquareBrackets(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := parseParamSquareBrackets(tt.input)
|
||||
if tt.err != nil {
|
||||
require.Error(t, err)
|
||||
|
@ -114,6 +121,8 @@ func Test_ParseParamSquareBrackets(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_parseToMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inputMap := map[string][]string{
|
||||
"key1": {"value1", "value2"},
|
||||
"key2": {"value3"},
|
||||
|
@ -147,6 +156,8 @@ func Test_parseToMap(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_FilterFlags(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
|
@ -172,8 +183,163 @@ func Test_FilterFlags(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := FilterFlags(tt.input)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatBindData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("string value with valid key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]string)
|
||||
err := formatBindData(out, data, "name", "John", false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(data["name"]) != 1 || data["name"][0] != "John" {
|
||||
t.Fatalf("expected data[\"name\"] = [John], got %v", data["name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported value type", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]string)
|
||||
err := formatBindData(out, data, "age", 30, false, false) // int is unsupported
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bracket notation parsing error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]string)
|
||||
err := formatBindData(out, data, "invalid[", "value", false, true) // malformed bracket notation
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handling multipart file headers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]*multipart.FileHeader)
|
||||
files := []*multipart.FileHeader{
|
||||
{Filename: "file1.txt"},
|
||||
{Filename: "file2.txt"},
|
||||
}
|
||||
err := formatBindData(out, data, "files", files, false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(data["files"]) != 2 {
|
||||
t.Fatalf("expected 2 files, got %d", len(data["files"]))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("type casting error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := map[string][]int{} // Incorrect type to force a casting error
|
||||
err := formatBindData(out, data, "key", "value", false, false)
|
||||
require.Equal(t, "unsupported value type: string", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAssignBindData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("splitting enabled with comma", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct {
|
||||
Colors []string `query:"colors"`
|
||||
}{}
|
||||
data := make(map[string][]string)
|
||||
assignBindData(&out, data, "colors", "red,blue,green", true)
|
||||
require.Len(t, data["colors"], 3)
|
||||
})
|
||||
|
||||
t.Run("splitting disabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var out []string
|
||||
data := make(map[string][]string)
|
||||
assignBindData(out, data, "color", "red,blue", false)
|
||||
require.Len(t, data["color"], 1)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_parseToStruct_MismatchedData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type User struct {
|
||||
Name string `query:"name"`
|
||||
Age int `query:"age"`
|
||||
}
|
||||
|
||||
data := map[string][]string{
|
||||
"name": {"John"},
|
||||
"age": {"invalidAge"},
|
||||
}
|
||||
|
||||
err := parseToStruct("query", &User{}, data)
|
||||
require.Error(t, err)
|
||||
require.EqualError(t, err, "bind: schema: error converting value for \"age\"")
|
||||
}
|
||||
|
||||
func Test_formatBindData_ErrorCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("unsupported value type int", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]string)
|
||||
err := formatBindData(out, data, "age", 30, false, false) // int is unsupported
|
||||
require.Error(t, err)
|
||||
require.EqualError(t, err, "unsupported value type: int")
|
||||
})
|
||||
|
||||
t.Run("unsupported value type map", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]string)
|
||||
err := formatBindData(out, data, "map", map[string]string{"key": "value"}, false, false) // map is unsupported
|
||||
require.Error(t, err)
|
||||
require.EqualError(t, err, "unsupported value type: map[string]string")
|
||||
})
|
||||
|
||||
t.Run("bracket notation parsing error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]string)
|
||||
err := formatBindData(out, data, "invalid[", "value", false, true) // malformed bracket notation
|
||||
require.Error(t, err)
|
||||
require.EqualError(t, err, "unmatched brackets")
|
||||
})
|
||||
|
||||
t.Run("type casting error for []string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out := struct{}{}
|
||||
data := make(map[string][]string)
|
||||
err := formatBindData(out, data, "names", 123, false, false) // invalid type for []string
|
||||
require.Error(t, err)
|
||||
require.EqualError(t, err, "unsupported value type: int")
|
||||
})
|
||||
}
|
||||
|
|
|
@ -120,6 +120,38 @@ curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=j
|
|||
curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000
|
||||
```
|
||||
|
||||
:::info
|
||||
If you need to bind multipart file, you can use `*multipart.FileHeader`, `*[]*multipart.FileHeader` or `[]*multipart.FileHeader` as a field type.
|
||||
:::
|
||||
|
||||
```go title="Example"
|
||||
type Person struct {
|
||||
Name string `form:"name"`
|
||||
Pass string `form:"pass"`
|
||||
Avatar *multipart.FileHeader `form:"avatar"`
|
||||
}
|
||||
|
||||
app.Post("/", func(c fiber.Ctx) error {
|
||||
p := new(Person)
|
||||
|
||||
if err := c.Bind().Form(p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println(p.Name) // john
|
||||
log.Println(p.Pass) // doe
|
||||
log.Println(p.Avatar.Filename) // file.txt
|
||||
|
||||
// ...
|
||||
})
|
||||
```
|
||||
|
||||
Run tests with the following `curl` command:
|
||||
|
||||
```bash
|
||||
curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" -F 'avatar=@filename' localhost:3000
|
||||
```
|
||||
|
||||
### JSON
|
||||
|
||||
Binds the request JSON body to a struct.
|
||||
|
|
|
@ -546,6 +546,7 @@ Fiber v3 introduces a new binding mechanism that simplifies the process of bindi
|
|||
- Unified binding from URL parameters, query parameters, headers, and request bodies.
|
||||
- Support for custom binders and constraints.
|
||||
- Improved error handling and validation.
|
||||
- Support multipart file binding for `*multipart.FileHeader`, `*[]*multipart.FileHeader`, and `[]*multipart.FileHeader` field types.
|
||||
|
||||
<details>
|
||||
<summary>Example</summary>
|
||||
|
|
Loading…
Reference in New Issue