👷 Improve test coverage

This commit is contained in:
Kiyon 2021-02-20 16:12:06 +08:00
parent 62d311133b
commit c477128e5b
2 changed files with 60 additions and 13 deletions

View File

@ -134,6 +134,7 @@ type Agent struct {
errs []error
formFiles []*FormFile
debugWriter io.Writer
mw multipartWriter
maxRedirectsCount int
boundary string
Name string
@ -142,8 +143,6 @@ type Agent struct {
parsed bool
}
var ErrorInvalidURI = fasthttp.ErrorInvalidURI
// Parse initializes URI and HostClient.
func (a *Agent) Parse() error {
if a.parsed {
@ -157,9 +156,6 @@ func (a *Agent) Parse() error {
}
uri := req.URI()
if uri == nil {
return ErrorInvalidURI
}
isTLS := false
scheme := uri.Scheme()
@ -567,27 +563,29 @@ func (a *Agent) Boundary(boundary string) *Agent {
// It is recommended obtaining args via AcquireArgs
// in performance-critical code.
func (a *Agent) MultipartForm(args *Args) *Agent {
mw := multipart.NewWriter(a.req.BodyWriter())
if a.mw == nil {
a.mw = multipart.NewWriter(a.req.BodyWriter())
}
if a.boundary != "" {
if err := mw.SetBoundary(a.boundary); err != nil {
if err := a.mw.SetBoundary(a.boundary); err != nil {
a.errs = append(a.errs, err)
return a
}
}
a.req.Header.SetMultipartFormBoundary(mw.Boundary())
a.req.Header.SetMultipartFormBoundary(a.mw.Boundary())
if args != nil {
args.VisitAll(func(key, value []byte) {
if err := mw.WriteField(getString(key), getString(value)); err != nil {
if err := a.mw.WriteField(getString(key), getString(value)); err != nil {
a.errs = append(a.errs, err)
}
})
}
for _, ff := range a.formFiles {
w, err := mw.CreateFormFile(ff.Fieldname, ff.Name)
w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name)
if err != nil {
a.errs = append(a.errs, err)
continue
@ -597,7 +595,7 @@ func (a *Agent) MultipartForm(args *Args) *Agent {
}
}
if err := mw.Close(); err != nil {
if err := a.mw.Close(); err != nil {
a.errs = append(a.errs, err)
}
@ -765,6 +763,7 @@ func (a *Agent) reset() {
a.args = nil
a.errs = a.errs[:0]
a.debugWriter = nil
a.mw = nil
a.reuse = false
a.parsed = false
a.maxRedirectsCount = 0
@ -934,3 +933,11 @@ var (
strHTTPS = []byte("https")
defaultUserAgent = "fiber"
)
type multipartWriter interface {
Boundary() string
SetBoundary(boundary string) error
CreateFormFile(fieldname, filename string) (io.Writer, error)
WriteField(fieldname, value string) error
Close() error
}

View File

@ -4,6 +4,8 @@ import (
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"io/ioutil"
"mime/multipart"
"net"
@ -584,7 +586,7 @@ func Test_Client_Agent_Form(t *testing.T) {
ReleaseArgs(args)
}
func Test_Client_Agent_Multipart(t *testing.T) {
func Test_Client_Agent_MultipartForm(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
@ -621,7 +623,25 @@ func Test_Client_Agent_Multipart(t *testing.T) {
ReleaseArgs(args)
}
func Test_Client_Agent_Multipart_SendFiles(t *testing.T) {
func Test_Client_Agent_MultipartForm_Errors(t *testing.T) {
t.Parallel()
a := AcquireAgent()
a.mw = &errorMultipartWriter{}
args := AcquireArgs()
args.Set("foo", "bar")
ff1 := &FormFile{"", "name1", []byte("content"), false}
ff2 := &FormFile{"", "name2", []byte("content"), false}
a.FileData(ff1, ff2).
MultipartForm(args)
utils.AssertEqual(t, 4, len(a.errs))
ReleaseArgs(args)
}
func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
@ -983,3 +1003,23 @@ func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), exce
type data struct {
Success bool `json:"success" xml:"success"`
}
type errorMultipartWriter struct {
count int
}
func (e *errorMultipartWriter) Boundary() string { return "myBoundary" }
func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil }
func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
if e.count == 0 {
e.count++
return nil, errors.New("CreateFormFile error")
}
return errorWriter{}, nil
}
func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
func (e *errorMultipartWriter) Close() error { return errors.New("Close error") }
type errorWriter struct{}
func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") }