diff --git a/client.go b/client.go index 07bb46b3..f8155576 100644 --- a/client.go +++ b/client.go @@ -134,6 +134,7 @@ type Agent struct { errs []error formFiles []*FormFile debugWriter io.Writer + mw multipartWriter maxRedirectsCount int boundary string Name string @@ -142,8 +143,6 @@ type Agent struct { parsed bool } -var ErrorInvalidURI = fasthttp.ErrorInvalidURI - // Parse initializes URI and HostClient. func (a *Agent) Parse() error { if a.parsed { @@ -157,9 +156,6 @@ func (a *Agent) Parse() error { } uri := req.URI() - if uri == nil { - return ErrorInvalidURI - } isTLS := false scheme := uri.Scheme() @@ -567,27 +563,29 @@ func (a *Agent) Boundary(boundary string) *Agent { // It is recommended obtaining args via AcquireArgs // in performance-critical code. func (a *Agent) MultipartForm(args *Args) *Agent { - mw := multipart.NewWriter(a.req.BodyWriter()) + if a.mw == nil { + a.mw = multipart.NewWriter(a.req.BodyWriter()) + } if a.boundary != "" { - if err := mw.SetBoundary(a.boundary); err != nil { + if err := a.mw.SetBoundary(a.boundary); err != nil { a.errs = append(a.errs, err) return a } } - a.req.Header.SetMultipartFormBoundary(mw.Boundary()) + a.req.Header.SetMultipartFormBoundary(a.mw.Boundary()) if args != nil { args.VisitAll(func(key, value []byte) { - if err := mw.WriteField(getString(key), getString(value)); err != nil { + if err := a.mw.WriteField(getString(key), getString(value)); err != nil { a.errs = append(a.errs, err) } }) } for _, ff := range a.formFiles { - w, err := mw.CreateFormFile(ff.Fieldname, ff.Name) + w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name) if err != nil { a.errs = append(a.errs, err) continue @@ -597,7 +595,7 @@ func (a *Agent) MultipartForm(args *Args) *Agent { } } - if err := mw.Close(); err != nil { + if err := a.mw.Close(); err != nil { a.errs = append(a.errs, err) } @@ -765,6 +763,7 @@ func (a *Agent) reset() { a.args = nil a.errs = a.errs[:0] a.debugWriter = nil + a.mw = nil a.reuse = false a.parsed = false a.maxRedirectsCount = 0 @@ -934,3 +933,11 @@ var ( strHTTPS = []byte("https") defaultUserAgent = "fiber" ) + +type multipartWriter interface { + Boundary() string + SetBoundary(boundary string) error + CreateFormFile(fieldname, filename string) (io.Writer, error) + WriteField(fieldname, value string) error + Close() error +} diff --git a/client_test.go b/client_test.go index c28d2954..424d69bc 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,8 @@ import ( "bytes" "crypto/tls" "encoding/base64" + "errors" + "io" "io/ioutil" "mime/multipart" "net" @@ -584,7 +586,7 @@ func Test_Client_Agent_Form(t *testing.T) { ReleaseArgs(args) } -func Test_Client_Agent_Multipart(t *testing.T) { +func Test_Client_Agent_MultipartForm(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() @@ -621,7 +623,25 @@ func Test_Client_Agent_Multipart(t *testing.T) { ReleaseArgs(args) } -func Test_Client_Agent_Multipart_SendFiles(t *testing.T) { +func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { + t.Parallel() + + a := AcquireAgent() + a.mw = &errorMultipartWriter{} + + args := AcquireArgs() + args.Set("foo", "bar") + + ff1 := &FormFile{"", "name1", []byte("content"), false} + ff2 := &FormFile{"", "name2", []byte("content"), false} + a.FileData(ff1, ff2). + MultipartForm(args) + + utils.AssertEqual(t, 4, len(a.errs)) + ReleaseArgs(args) +} + +func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() @@ -983,3 +1003,23 @@ func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), exce type data struct { Success bool `json:"success" xml:"success"` } + +type errorMultipartWriter struct { + count int +} + +func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } +func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } +func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { + if e.count == 0 { + e.count++ + return nil, errors.New("CreateFormFile error") + } + return errorWriter{}, nil +} +func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } +func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } + +type errorWriter struct{} + +func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") }