diff --git a/binder/form.go b/binder/form.go index a8f5b852..fab28034 100644 --- a/binder/form.go +++ b/binder/form.go @@ -1,6 +1,8 @@ package binder import ( + "mime/multipart" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -59,7 +61,15 @@ func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { } } - return parse(b.Name(), out, data) + files := make(map[string][]*multipart.FileHeader) + for key, values := range multipartForm.File { + err = formatBindData(out, files, key, values, b.EnableSplitting, true) + if err != nil { + return err + } + } + + return parse(b.Name(), out, data, files) } // Reset resets the FormBinding binder. diff --git a/binder/form_test.go b/binder/form_test.go index 55023cb3..d961f873 100644 --- a/binder/form_test.go +++ b/binder/form_test.go @@ -2,6 +2,7 @@ package binder import ( "bytes" + "io" "mime/multipart" "testing" @@ -98,10 +99,12 @@ func Test_FormBinder_BindMultipart(t *testing.T) { } type User struct { - Name string `form:"name"` - Names []string `form:"names"` - Posts []Post `form:"posts"` - Age int `form:"age"` + Avatar *multipart.FileHeader `form:"avatar"` + Name string `form:"name"` + Names []string `form:"names"` + Posts []Post `form:"posts"` + Avatars []*multipart.FileHeader `form:"avatars"` + Age int `form:"age"` } var user User @@ -118,6 +121,24 @@ func Test_FormBinder_BindMultipart(t *testing.T) { require.NoError(t, mw.WriteField("posts[1][title]", "post2")) require.NoError(t, mw.WriteField("posts[2][title]", "post3")) + writer, err := mw.CreateFormFile("avatar", "avatar.txt") + require.NoError(t, err) + + _, err = writer.Write([]byte("avatar")) + require.NoError(t, err) + + writer, err = mw.CreateFormFile("avatars", "avatar1.txt") + require.NoError(t, err) + + _, err = writer.Write([]byte("avatar1")) + require.NoError(t, err) + + writer, err = mw.CreateFormFile("avatars", "avatar2.txt") + require.NoError(t, err) + + _, err = writer.Write([]byte("avatar2")) + require.NoError(t, err) + require.NoError(t, mw.Close()) req.Header.SetContentType(mw.FormDataContentType()) @@ -127,7 +148,7 @@ func Test_FormBinder_BindMultipart(t *testing.T) { fasthttp.ReleaseRequest(req) }) - err := b.Bind(req, &user) + err = b.Bind(req, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) @@ -139,6 +160,38 @@ func Test_FormBinder_BindMultipart(t *testing.T) { require.Equal(t, "post1", user.Posts[0].Title) require.Equal(t, "post2", user.Posts[1].Title) require.Equal(t, "post3", user.Posts[2].Title) + + require.NotNil(t, user.Avatar) + require.Equal(t, "avatar.txt", user.Avatar.Filename) + require.Equal(t, "application/octet-stream", user.Avatar.Header.Get("Content-Type")) + + file, err := user.Avatar.Open() + require.NoError(t, err) + + content, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, "avatar", string(content)) + + require.Len(t, user.Avatars, 2) + require.Equal(t, "avatar1.txt", user.Avatars[0].Filename) + require.Equal(t, "application/octet-stream", user.Avatars[0].Header.Get("Content-Type")) + + file, err = user.Avatars[0].Open() + require.NoError(t, err) + + content, err = io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, "avatar1", string(content)) + + require.Equal(t, "avatar2.txt", user.Avatars[1].Filename) + require.Equal(t, "application/octet-stream", user.Avatars[1].Header.Get("Content-Type")) + + file, err = user.Avatars[1].Open() + require.NoError(t, err) + + content, err = io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, "avatar2", string(content)) } func Benchmark_FormBinder_BindMultipart(b *testing.B) { diff --git a/binder/mapping.go b/binder/mapping.go index 70cb9cbc..bc95d028 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -3,6 +3,7 @@ package binder import ( "errors" "fmt" + "mime/multipart" "reflect" "strings" "sync" @@ -69,7 +70,7 @@ func init() { } // parse data into the map or struct -func parse(aliasTag string, out any, data map[string][]string) error { +func parse(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error { ptrVal := reflect.ValueOf(out) // Get pointer value @@ -83,11 +84,11 @@ func parse(aliasTag string, out any, data map[string][]string) error { } // Parse into the struct - return parseToStruct(aliasTag, out, data) + return parseToStruct(aliasTag, out, data, files...) } // Parse data into the struct with gorilla/schema -func parseToStruct(aliasTag string, out any, data map[string][]string) error { +func parseToStruct(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error { // Get decoder from pool schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder) //nolint:errcheck,forcetypeassert // not needed defer decoderPoolMap[aliasTag].Put(schemaDecoder) @@ -95,7 +96,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error { // Set alias tag schemaDecoder.SetAliasTag(aliasTag) - if err := schemaDecoder.Decode(out, data); err != nil { + if err := schemaDecoder.Decode(out, data, files...); err != nil { return fmt.Errorf("bind: %w", err) } @@ -250,7 +251,7 @@ func FilterFlags(content string) string { return content } -func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay +func formatBindData[T, K any](out any, data map[string][]T, key string, value K, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay var err error if supportBracketNotation && strings.Contains(key, "[") { key, err = parseParamSquareBrackets(key) @@ -261,10 +262,28 @@ func formatBindData[T any](out any, data map[string][]string, key string, value switch v := any(value).(type) { case string: - assignBindData(out, data, key, v, enableSplitting) + dataMap, ok := any(data).(map[string][]string) + if !ok { + return fmt.Errorf("unsupported value type: %T", value) + } + + assignBindData(out, dataMap, key, v, enableSplitting) case []string: + dataMap, ok := any(data).(map[string][]string) + if !ok { + return fmt.Errorf("unsupported value type: %T", value) + } + for _, val := range v { - assignBindData(out, data, key, val, enableSplitting) + assignBindData(out, dataMap, key, val, enableSplitting) + } + case []*multipart.FileHeader: + for _, val := range v { + valT, ok := any(val).(T) + if !ok { + return fmt.Errorf("unsupported value type: %T", value) + } + data[key] = append(data[key], valT) } default: return fmt.Errorf("unsupported value type: %T", value) diff --git a/binder/mapping_test.go b/binder/mapping_test.go index 75cdc783..9c7b92ee 100644 --- a/binder/mapping_test.go +++ b/binder/mapping_test.go @@ -2,6 +2,7 @@ package binder import ( "errors" + "mime/multipart" "reflect" "testing" @@ -9,6 +10,8 @@ import ( ) func Test_EqualFieldType(t *testing.T) { + t.Parallel() + var out int require.False(t, equalFieldType(&out, reflect.Int, "key")) @@ -47,6 +50,8 @@ func Test_EqualFieldType(t *testing.T) { } func Test_ParseParamSquareBrackets(t *testing.T) { + t.Parallel() + tests := []struct { err error input string @@ -101,6 +106,8 @@ func Test_ParseParamSquareBrackets(t *testing.T) { for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { + t.Parallel() + result, err := parseParamSquareBrackets(tt.input) if tt.err != nil { require.Error(t, err) @@ -114,6 +121,8 @@ func Test_ParseParamSquareBrackets(t *testing.T) { } func Test_parseToMap(t *testing.T) { + t.Parallel() + inputMap := map[string][]string{ "key1": {"value1", "value2"}, "key2": {"value3"}, @@ -147,6 +156,8 @@ func Test_parseToMap(t *testing.T) { } func Test_FilterFlags(t *testing.T) { + t.Parallel() + tests := []struct { input string expected string @@ -172,8 +183,163 @@ func Test_FilterFlags(t *testing.T) { for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { + t.Parallel() + result := FilterFlags(tt.input) require.Equal(t, tt.expected, result) }) } } + +func TestFormatBindData(t *testing.T) { + t.Parallel() + + t.Run("string value with valid key", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]string) + err := formatBindData(out, data, "name", "John", false, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(data["name"]) != 1 || data["name"][0] != "John" { + t.Fatalf("expected data[\"name\"] = [John], got %v", data["name"]) + } + }) + + t.Run("unsupported value type", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]string) + err := formatBindData(out, data, "age", 30, false, false) // int is unsupported + if err == nil { + t.Fatal("expected an error, got nil") + } + }) + + t.Run("bracket notation parsing error", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]string) + err := formatBindData(out, data, "invalid[", "value", false, true) // malformed bracket notation + if err == nil { + t.Fatal("expected an error, got nil") + } + }) + + t.Run("handling multipart file headers", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]*multipart.FileHeader) + files := []*multipart.FileHeader{ + {Filename: "file1.txt"}, + {Filename: "file2.txt"}, + } + err := formatBindData(out, data, "files", files, false, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(data["files"]) != 2 { + t.Fatalf("expected 2 files, got %d", len(data["files"])) + } + }) + + t.Run("type casting error", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := map[string][]int{} // Incorrect type to force a casting error + err := formatBindData(out, data, "key", "value", false, false) + require.Equal(t, "unsupported value type: string", err.Error()) + }) +} + +func TestAssignBindData(t *testing.T) { + t.Parallel() + + t.Run("splitting enabled with comma", func(t *testing.T) { + t.Parallel() + + out := struct { + Colors []string `query:"colors"` + }{} + data := make(map[string][]string) + assignBindData(&out, data, "colors", "red,blue,green", true) + require.Len(t, data["colors"], 3) + }) + + t.Run("splitting disabled", func(t *testing.T) { + t.Parallel() + + var out []string + data := make(map[string][]string) + assignBindData(out, data, "color", "red,blue", false) + require.Len(t, data["color"], 1) + }) +} + +func Test_parseToStruct_MismatchedData(t *testing.T) { + t.Parallel() + + type User struct { + Name string `query:"name"` + Age int `query:"age"` + } + + data := map[string][]string{ + "name": {"John"}, + "age": {"invalidAge"}, + } + + err := parseToStruct("query", &User{}, data) + require.Error(t, err) + require.EqualError(t, err, "bind: schema: error converting value for \"age\"") +} + +func Test_formatBindData_ErrorCases(t *testing.T) { + t.Parallel() + + t.Run("unsupported value type int", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]string) + err := formatBindData(out, data, "age", 30, false, false) // int is unsupported + require.Error(t, err) + require.EqualError(t, err, "unsupported value type: int") + }) + + t.Run("unsupported value type map", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]string) + err := formatBindData(out, data, "map", map[string]string{"key": "value"}, false, false) // map is unsupported + require.Error(t, err) + require.EqualError(t, err, "unsupported value type: map[string]string") + }) + + t.Run("bracket notation parsing error", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]string) + err := formatBindData(out, data, "invalid[", "value", false, true) // malformed bracket notation + require.Error(t, err) + require.EqualError(t, err, "unmatched brackets") + }) + + t.Run("type casting error for []string", func(t *testing.T) { + t.Parallel() + + out := struct{}{} + data := make(map[string][]string) + err := formatBindData(out, data, "names", 123, false, false) // invalid type for []string + require.Error(t, err) + require.EqualError(t, err, "unsupported value type: int") + }) +} diff --git a/docs/api/bind.md b/docs/api/bind.md index d2b33631..eaad6305 100644 --- a/docs/api/bind.md +++ b/docs/api/bind.md @@ -120,6 +120,38 @@ curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=j curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000 ``` +:::info +If you need to bind multipart file, you can use `*multipart.FileHeader`, `*[]*multipart.FileHeader` or `[]*multipart.FileHeader` as a field type. +::: + +```go title="Example" +type Person struct { + Name string `form:"name"` + Pass string `form:"pass"` + Avatar *multipart.FileHeader `form:"avatar"` +} + +app.Post("/", func(c fiber.Ctx) error { + p := new(Person) + + if err := c.Bind().Form(p); err != nil { + return err + } + + log.Println(p.Name) // john + log.Println(p.Pass) // doe + log.Println(p.Avatar.Filename) // file.txt + + // ... +}) +``` + +Run tests with the following `curl` command: + +```bash +curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" -F 'avatar=@filename' localhost:3000 +``` + ### JSON Binds the request JSON body to a struct. diff --git a/docs/whats_new.md b/docs/whats_new.md index 7f0b6322..4185a7e3 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -546,6 +546,7 @@ Fiber v3 introduces a new binding mechanism that simplifies the process of bindi - Unified binding from URL parameters, query parameters, headers, and request bodies. - Support for custom binders and constraints. - Improved error handling and validation. +- Support multipart file binding for `*multipart.FileHeader`, `*[]*multipart.FileHeader`, and `[]*multipart.FileHeader` field types.
Example