diff --git a/client.go b/client.go new file mode 100644 index 00000000..b5c36943 --- /dev/null +++ b/client.go @@ -0,0 +1,985 @@ +package fiber + +import ( + "bytes" + "crypto/tls" + "encoding/xml" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/gofiber/fiber/v2/utils" + + "github.com/gofiber/fiber/v2/internal/encoding/json" + "github.com/valyala/fasthttp" +) + +// Request represents HTTP request. +// +// It is forbidden copying Request instances. Create new instances +// and use CopyTo instead. +// +// Request instance MUST NOT be used from concurrently running goroutines. +// Copy from fasthttp +type Request = fasthttp.Request + +// Response represents HTTP response. +// +// It is forbidden copying Response instances. Create new instances +// and use CopyTo instead. +// +// Response instance MUST NOT be used from concurrently running goroutines. +// Copy from fasthttp +type Response = fasthttp.Response + +// Args represents query arguments. +// +// It is forbidden copying Args instances. Create new instances instead +// and use CopyTo(). +// +// Args instance MUST NOT be used from concurrently running goroutines. +// Copy from fasthttp +type Args = fasthttp.Args + +var defaultClient Client + +// Client implements http client. +// +// It is safe calling Client methods from concurrently running goroutines. +type Client struct { + // UserAgent is used in User-Agent request header. + UserAgent string + + // NoDefaultUserAgentHeader when set to true, causes the default + // User-Agent header to be excluded from the Request. + NoDefaultUserAgentHeader bool + + // When set by an external client of Fiber it will use the provided implementation of a + // JSONMarshal + // + // Allowing for flexibility in using another json library for encoding + JSONEncoder utils.JSONMarshal + + // When set by an external client of Fiber it will use the provided implementation of a + // JSONUnmarshal + // + // Allowing for flexibility in using another json library for decoding + JSONDecoder utils.JSONUnmarshal +} + +// Get returns a agent with http method GET. +func Get(url string) *Agent { return defaultClient.Get(url) } + +// Get returns a agent with http method GET. +func (c *Client) Get(url string) *Agent { + return c.createAgent(MethodGet, url) +} + +// Head returns a agent with http method HEAD. +func Head(url string) *Agent { return defaultClient.Head(url) } + +// Head returns a agent with http method GET. +func (c *Client) Head(url string) *Agent { + return c.createAgent(MethodHead, url) +} + +// Post sends POST request to the given url. +func Post(url string) *Agent { return defaultClient.Post(url) } + +// Post sends POST request to the given url. +func (c *Client) Post(url string) *Agent { + return c.createAgent(MethodPost, url) +} + +// Put sends PUT request to the given url. +func Put(url string) *Agent { return defaultClient.Put(url) } + +// Put sends PUT request to the given url. +func (c *Client) Put(url string) *Agent { + return c.createAgent(MethodPut, url) +} + +// Patch sends PATCH request to the given url. +func Patch(url string) *Agent { return defaultClient.Patch(url) } + +// Patch sends PATCH request to the given url. +func (c *Client) Patch(url string) *Agent { + return c.createAgent(MethodPatch, url) +} + +// Delete sends DELETE request to the given url. +func Delete(url string) *Agent { return defaultClient.Delete(url) } + +// Delete sends DELETE request to the given url. +func (c *Client) Delete(url string) *Agent { + return c.createAgent(MethodDelete, url) +} + +func (c *Client) createAgent(method, url string) *Agent { + a := AcquireAgent() + a.req.Header.SetMethod(method) + a.req.SetRequestURI(url) + + a.Name = c.UserAgent + a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader + a.jsonDecoder = c.JSONDecoder + a.jsonEncoder = c.JSONEncoder + + if err := a.Parse(); err != nil { + a.errs = append(a.errs, err) + } + + return a +} + +// Agent is an object storing all request data for client. +// Agent instance MUST NOT be used from concurrently running goroutines. +type Agent struct { + // Name is used in User-Agent request header. + Name string + + // NoDefaultUserAgentHeader when set to true, causes the default + // User-Agent header to be excluded from the Request. + NoDefaultUserAgentHeader bool + + // HostClient is an embedded fasthttp HostClient + *fasthttp.HostClient + + req *Request + resp *Response + dest []byte + args *Args + timeout time.Duration + errs []error + formFiles []*FormFile + debugWriter io.Writer + mw multipartWriter + jsonEncoder utils.JSONMarshal + jsonDecoder utils.JSONUnmarshal + maxRedirectsCount int + boundary string + reuse bool + parsed bool +} + +// Parse initializes URI and HostClient. +func (a *Agent) Parse() error { + if a.parsed { + return nil + } + a.parsed = true + + uri := a.req.URI() + + isTLS := false + scheme := uri.Scheme() + if bytes.Equal(scheme, strHTTPS) { + isTLS = true + } else if !bytes.Equal(scheme, strHTTP) { + return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) + } + + name := a.Name + if name == "" && !a.NoDefaultUserAgentHeader { + name = defaultUserAgent + } + + a.HostClient = &fasthttp.HostClient{ + Addr: addMissingPort(string(uri.Host()), isTLS), + Name: name, + NoDefaultUserAgentHeader: a.NoDefaultUserAgentHeader, + IsTLS: isTLS, + } + + return nil +} + +func addMissingPort(addr string, isTLS bool) string { + n := strings.Index(addr, ":") + if n >= 0 { + return addr + } + port := 80 + if isTLS { + port = 443 + } + return net.JoinHostPort(addr, strconv.Itoa(port)) +} + +/************************** Header Setting **************************/ + +// Set sets the given 'key: value' header. +// +// Use Add for setting multiple header values under the same key. +func (a *Agent) Set(k, v string) *Agent { + a.req.Header.Set(k, v) + + return a +} + +// SetBytesK sets the given 'key: value' header. +// +// Use AddBytesK for setting multiple header values under the same key. +func (a *Agent) SetBytesK(k []byte, v string) *Agent { + a.req.Header.SetBytesK(k, v) + + return a +} + +// SetBytesV sets the given 'key: value' header. +// +// Use AddBytesV for setting multiple header values under the same key. +func (a *Agent) SetBytesV(k string, v []byte) *Agent { + a.req.Header.SetBytesV(k, v) + + return a +} + +// SetBytesKV sets the given 'key: value' header. +// +// Use AddBytesKV for setting multiple header values under the same key. +func (a *Agent) SetBytesKV(k []byte, v []byte) *Agent { + a.req.Header.SetBytesKV(k, v) + + return a +} + +// Add adds the given 'key: value' header. +// +// Multiple headers with the same key may be added with this function. +// Use Set for setting a single header for the given key. +func (a *Agent) Add(k, v string) *Agent { + a.req.Header.Add(k, v) + + return a +} + +// AddBytesK adds the given 'key: value' header. +// +// Multiple headers with the same key may be added with this function. +// Use SetBytesK for setting a single header for the given key. +func (a *Agent) AddBytesK(k []byte, v string) *Agent { + a.req.Header.AddBytesK(k, v) + + return a +} + +// AddBytesV adds the given 'key: value' header. +// +// Multiple headers with the same key may be added with this function. +// Use SetBytesV for setting a single header for the given key. +func (a *Agent) AddBytesV(k string, v []byte) *Agent { + a.req.Header.AddBytesV(k, v) + + return a +} + +// AddBytesKV adds the given 'key: value' header. +// +// Multiple headers with the same key may be added with this function. +// Use SetBytesKV for setting a single header for the given key. +func (a *Agent) AddBytesKV(k []byte, v []byte) *Agent { + a.req.Header.AddBytesKV(k, v) + + return a +} + +// ConnectionClose sets 'Connection: close' header. +func (a *Agent) ConnectionClose() *Agent { + a.req.Header.SetConnectionClose() + + return a +} + +// UserAgent sets User-Agent header value. +func (a *Agent) UserAgent(userAgent string) *Agent { + a.req.Header.SetUserAgent(userAgent) + + return a +} + +// UserAgentBytes sets User-Agent header value. +func (a *Agent) UserAgentBytes(userAgent []byte) *Agent { + a.req.Header.SetUserAgentBytes(userAgent) + + return a +} + +// Cookie sets one 'key: value' cookie. +func (a *Agent) Cookie(key, value string) *Agent { + a.req.Header.SetCookie(key, value) + + return a +} + +// CookieBytesK sets one 'key: value' cookie. +func (a *Agent) CookieBytesK(key []byte, value string) *Agent { + a.req.Header.SetCookieBytesK(key, value) + + return a +} + +// CookieBytesKV sets one 'key: value' cookie. +func (a *Agent) CookieBytesKV(key, value []byte) *Agent { + a.req.Header.SetCookieBytesKV(key, value) + + return a +} + +// Cookies sets multiple 'key: value' cookies. +func (a *Agent) Cookies(kv ...string) *Agent { + for i := 1; i < len(kv); i += 2 { + a.req.Header.SetCookie(kv[i-1], kv[i]) + } + + return a +} + +// CookiesBytesKV sets multiple 'key: value' cookies. +func (a *Agent) CookiesBytesKV(kv ...[]byte) *Agent { + for i := 1; i < len(kv); i += 2 { + a.req.Header.SetCookieBytesKV(kv[i-1], kv[i]) + } + + return a +} + +// Referer sets Referer header value. +func (a *Agent) Referer(referer string) *Agent { + a.req.Header.SetReferer(referer) + + return a +} + +// RefererBytes sets Referer header value. +func (a *Agent) RefererBytes(referer []byte) *Agent { + a.req.Header.SetRefererBytes(referer) + + return a +} + +// ContentType sets Content-Type header value. +func (a *Agent) ContentType(contentType string) *Agent { + a.req.Header.SetContentType(contentType) + + return a +} + +// ContentTypeBytes sets Content-Type header value. +func (a *Agent) ContentTypeBytes(contentType []byte) *Agent { + a.req.Header.SetContentTypeBytes(contentType) + + return a +} + +/************************** End Header Setting **************************/ + +/************************** URI Setting **************************/ + +// Host sets host for the uri. +func (a *Agent) Host(host string) *Agent { + a.req.URI().SetHost(host) + + return a +} + +// HostBytes sets host for the URI. +func (a *Agent) HostBytes(host []byte) *Agent { + a.req.URI().SetHostBytes(host) + + return a +} + +// QueryString sets URI query string. +func (a *Agent) QueryString(queryString string) *Agent { + a.req.URI().SetQueryString(queryString) + + return a +} + +// QueryStringBytes sets URI query string. +func (a *Agent) QueryStringBytes(queryString []byte) *Agent { + a.req.URI().SetQueryStringBytes(queryString) + + return a +} + +// BasicAuth sets URI username and password. +func (a *Agent) BasicAuth(username, password string) *Agent { + a.req.URI().SetUsername(username) + a.req.URI().SetPassword(password) + + return a +} + +// BasicAuthBytes sets URI username and password. +func (a *Agent) BasicAuthBytes(username, password []byte) *Agent { + a.req.URI().SetUsernameBytes(username) + a.req.URI().SetPasswordBytes(password) + + return a +} + +/************************** End URI Setting **************************/ + +/************************** Request Setting **************************/ + +// BodyString sets request body. +func (a *Agent) BodyString(bodyString string) *Agent { + a.req.SetBodyString(bodyString) + + return a +} + +// Body sets request body. +func (a *Agent) Body(body []byte) *Agent { + a.req.SetBody(body) + + return a +} + +// BodyStream sets request body stream and, optionally body size. +// +// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes +// before returning io.EOF. +// +// If bodySize < 0, then bodyStream is read until io.EOF. +// +// bodyStream.Close() is called after finishing reading all body data +// if it implements io.Closer. +// +// Note that GET and HEAD requests cannot have body. +func (a *Agent) BodyStream(bodyStream io.Reader, bodySize int) *Agent { + a.req.SetBodyStream(bodyStream, bodySize) + + return a +} + +// JSON sends a JSON request. +func (a *Agent) JSON(v interface{}) *Agent { + if a.jsonEncoder == nil { + a.jsonEncoder = json.Marshal + } + + a.req.Header.SetContentType(MIMEApplicationJSON) + + if body, err := a.jsonEncoder(v); err != nil { + a.errs = append(a.errs, err) + } else { + a.req.SetBody(body) + } + + return a +} + +// XML sends a XML request. +func (a *Agent) XML(v interface{}) *Agent { + a.req.Header.SetContentType(MIMEApplicationXML) + + if body, err := xml.Marshal(v); err != nil { + a.errs = append(a.errs, err) + } else { + a.req.SetBody(body) + } + + return a +} + +// Form sends form request with body if args is non-nil. +// +// It is recommended obtaining args via AcquireArgs and release it +// manually in performance-critical code. +func (a *Agent) Form(args *Args) *Agent { + a.req.Header.SetContentType(MIMEApplicationForm) + + if args != nil { + a.req.SetBody(args.QueryString()) + } + + return a +} + +// FormFile represents multipart form file +type FormFile struct { + // Fieldname is form file's field name + Fieldname string + // Name is form file's name + Name string + // Content is form file's content + Content []byte + // autoRelease indicates if returns the object + // acquired via AcquireFormFile to the pool. + autoRelease bool +} + +// FileData appends files for multipart form request. +// +// It is recommended obtaining formFile via AcquireFormFile and release it +// manually in performance-critical code. +func (a *Agent) FileData(formFiles ...*FormFile) *Agent { + a.formFiles = append(a.formFiles, formFiles...) + + return a +} + +// SendFile reads file and appends it to multipart form request. +func (a *Agent) SendFile(filename string, fieldname ...string) *Agent { + content, err := ioutil.ReadFile(filepath.Clean(filename)) + if err != nil { + a.errs = append(a.errs, err) + return a + } + + ff := AcquireFormFile() + if len(fieldname) > 0 && fieldname[0] != "" { + ff.Fieldname = fieldname[0] + } else { + ff.Fieldname = "file" + strconv.Itoa(len(a.formFiles)+1) + } + ff.Name = filepath.Base(filename) + ff.Content = append(ff.Content, content...) + ff.autoRelease = true + + a.formFiles = append(a.formFiles, ff) + + return a +} + +// SendFiles reads files and appends them to multipart form request. +// +// Examples: +// SendFile("/path/to/file1", "fieldname1", "/path/to/file2") +func (a *Agent) SendFiles(filenamesAndFieldnames ...string) *Agent { + pairs := len(filenamesAndFieldnames) + if pairs&1 == 1 { + filenamesAndFieldnames = append(filenamesAndFieldnames, "") + } + + for i := 0; i < pairs; i += 2 { + a.SendFile(filenamesAndFieldnames[i], filenamesAndFieldnames[i+1]) + } + + return a +} + +// Boundary sets boundary for multipart form request. +func (a *Agent) Boundary(boundary string) *Agent { + a.boundary = boundary + + return a +} + +// MultipartForm sends multipart form request with k-v and files. +// +// It is recommended obtaining args via AcquireArgs and release it +// manually in performance-critical code. +func (a *Agent) MultipartForm(args *Args) *Agent { + if a.mw == nil { + a.mw = multipart.NewWriter(a.req.BodyWriter()) + } + + if a.boundary != "" { + if err := a.mw.SetBoundary(a.boundary); err != nil { + a.errs = append(a.errs, err) + return a + } + } + + a.req.Header.SetMultipartFormBoundary(a.mw.Boundary()) + + if args != nil { + args.VisitAll(func(key, value []byte) { + if err := a.mw.WriteField(getString(key), getString(value)); err != nil { + a.errs = append(a.errs, err) + } + }) + } + + for _, ff := range a.formFiles { + w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name) + if err != nil { + a.errs = append(a.errs, err) + continue + } + if _, err = w.Write(ff.Content); err != nil { + a.errs = append(a.errs, err) + } + } + + if err := a.mw.Close(); err != nil { + a.errs = append(a.errs, err) + } + + return a +} + +/************************** End Request Setting **************************/ + +/************************** Agent Setting **************************/ + +// Debug mode enables logging request and response detail +func (a *Agent) Debug(w ...io.Writer) *Agent { + a.debugWriter = os.Stdout + if len(w) > 0 { + a.debugWriter = w[0] + } + + return a +} + +// Timeout sets request timeout duration. +func (a *Agent) Timeout(timeout time.Duration) *Agent { + a.timeout = timeout + + return a +} + +// Reuse enables the Agent instance to be used again after one request. +// +// If agent is reusable, then it should be released manually when it is no +// longer used. +func (a *Agent) Reuse() *Agent { + a.reuse = true + + return a +} + +// InsecureSkipVerify controls whether the Agent verifies the server +// certificate chain and host name. +func (a *Agent) InsecureSkipVerify() *Agent { + if a.HostClient.TLSConfig == nil { + /* #nosec G402 */ + a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } else { + /* #nosec G402 */ + a.HostClient.TLSConfig.InsecureSkipVerify = true + } + + return a +} + +// TLSConfig sets tls config. +func (a *Agent) TLSConfig(config *tls.Config) *Agent { + a.HostClient.TLSConfig = config + + return a +} + +// MaxRedirectsCount sets max redirect count for GET and HEAD. +func (a *Agent) MaxRedirectsCount(count int) *Agent { + a.maxRedirectsCount = count + + return a +} + +// JSONEncoder sets custom json encoder. +func (a *Agent) JSONEncoder(jsonEncoder utils.JSONMarshal) *Agent { + a.jsonEncoder = jsonEncoder + + return a +} + +// JSONDecoder sets custom json decoder. +func (a *Agent) JSONDecoder(jsonDecoder utils.JSONUnmarshal) *Agent { + a.jsonDecoder = jsonDecoder + + return a +} + +// Request returns Agent request instance. +func (a *Agent) Request() *Request { + return a.req +} + +// SetResponse sets custom response for the Agent instance. +// +// It is recommended obtaining custom response via AcquireResponse and release it +// manually in performance-critical code. +func (a *Agent) SetResponse(customResp *Response) *Agent { + a.resp = customResp + + return a +} + +// Dest sets custom dest. +// +// The contents of dest will be replaced by the response body, if the dest +// is too small a new slice will be allocated. +func (a *Agent) Dest(dest []byte) *Agent { + a.dest = dest + + return a +} + +/************************** End Agent Setting **************************/ + +// Bytes returns the status code, bytes body and errors of url. +func (a *Agent) Bytes() (code int, body []byte, errs []error) { + fmt.Println("[Warning] client is still in beta, API might change in the future!") + + defer a.release() + + if errs = append(errs, a.errs...); len(errs) > 0 { + return + } + + var ( + req = a.req + resp *Response + nilResp bool + ) + + if a.resp == nil { + resp = AcquireResponse() + nilResp = true + } else { + resp = a.resp + } + + defer func() { + if a.debugWriter != nil { + printDebugInfo(req, resp, a.debugWriter) + } + + if len(errs) == 0 { + code = resp.StatusCode() + } + + body = append(a.dest, resp.Body()...) + + if nilResp { + ReleaseResponse(resp) + } + }() + + if a.timeout > 0 { + if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil { + errs = append(errs, err) + return + } + } + + if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) { + if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil { + errs = append(errs, err) + return + } + } + + if err := a.HostClient.Do(req, resp); err != nil { + errs = append(errs, err) + } + + return +} + +func printDebugInfo(req *Request, resp *Response, w io.Writer) { + msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr()) + _, _ = w.Write(getBytes(msg)) + _, _ = req.WriteTo(w) + _, _ = resp.WriteTo(w) +} + +// String returns the status code, string body and errors of url. +func (a *Agent) String() (int, string, []error) { + code, body, errs := a.Bytes() + + return code, getString(body), errs +} + +// Struct returns the status code, bytes body and errors of url. +// And bytes body will be unmarshalled to given v. +func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) { + if a.jsonDecoder == nil { + a.jsonDecoder = json.Unmarshal + } + + if code, body, errs = a.Bytes(); len(errs) > 0 { + return + } + + if err := a.jsonDecoder(body, v); err != nil { + errs = append(errs, err) + } + + return +} + +func (a *Agent) release() { + if !a.reuse { + ReleaseAgent(a) + } else { + a.errs = a.errs[:0] + } +} + +func (a *Agent) reset() { + a.HostClient = nil + a.req.Reset() + a.resp = nil + a.dest = nil + a.timeout = 0 + a.args = nil + a.errs = a.errs[:0] + a.debugWriter = nil + a.mw = nil + a.reuse = false + a.parsed = false + a.maxRedirectsCount = 0 + a.boundary = "" + a.Name = "" + a.NoDefaultUserAgentHeader = false + for i, ff := range a.formFiles { + if ff.autoRelease { + ReleaseFormFile(ff) + } + a.formFiles[i] = nil + } + a.formFiles = a.formFiles[:0] +} + +var ( + clientPool sync.Pool + agentPool sync.Pool + responsePool sync.Pool + argsPool sync.Pool + formFilePool sync.Pool +) + +// AcquireClient returns an empty Client instance from client pool. +// +// The returned Client instance may be passed to ReleaseClient when it is +// no longer needed. This allows Client recycling, reduces GC pressure +// and usually improves performance. +func AcquireClient() *Client { + v := clientPool.Get() + if v == nil { + return &Client{} + } + return v.(*Client) +} + +// ReleaseClient returns c acquired via AcquireClient to client pool. +// +// It is forbidden accessing req and/or its' members after returning +// it to client pool. +func ReleaseClient(c *Client) { + c.UserAgent = "" + c.NoDefaultUserAgentHeader = false + + clientPool.Put(c) +} + +// AcquireAgent returns an empty Agent instance from Agent pool. +// +// The returned Agent instance may be passed to ReleaseAgent when it is +// no longer needed. This allows Agent recycling, reduces GC pressure +// and usually improves performance. +func AcquireAgent() *Agent { + v := agentPool.Get() + if v == nil { + return &Agent{req: &Request{}} + } + return v.(*Agent) +} + +// ReleaseAgent returns a acquired via AcquireAgent to Agent pool. +// +// It is forbidden accessing req and/or its' members after returning +// it to Agent pool. +func ReleaseAgent(a *Agent) { + a.reset() + agentPool.Put(a) +} + +// AcquireResponse returns an empty Response instance from response pool. +// +// The returned Response instance may be passed to ReleaseResponse when it is +// no longer needed. This allows Response recycling, reduces GC pressure +// and usually improves performance. +// Copy from fasthttp +func AcquireResponse() *Response { + v := responsePool.Get() + if v == nil { + return &Response{} + } + return v.(*Response) +} + +// ReleaseResponse return resp acquired via AcquireResponse to response pool. +// +// It is forbidden accessing resp and/or its' members after returning +// it to response pool. +// Copy from fasthttp +func ReleaseResponse(resp *Response) { + resp.Reset() + responsePool.Put(resp) +} + +// AcquireArgs returns an empty Args object from the pool. +// +// The returned Args may be returned to the pool with ReleaseArgs +// when no longer needed. This allows reducing GC load. +// Copy from fasthttp +func AcquireArgs() *Args { + v := argsPool.Get() + if v == nil { + return &Args{} + } + return v.(*Args) +} + +// ReleaseArgs returns the object acquired via AcquireArgs to the pool. +// +// String not access the released Args object, otherwise data races may occur. +// Copy from fasthttp +func ReleaseArgs(a *Args) { + a.Reset() + argsPool.Put(a) +} + +// AcquireFormFile returns an empty FormFile object from the pool. +// +// The returned FormFile may be returned to the pool with ReleaseFormFile +// when no longer needed. This allows reducing GC load. +func AcquireFormFile() *FormFile { + v := formFilePool.Get() + if v == nil { + return &FormFile{} + } + return v.(*FormFile) +} + +// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool. +// +// String not access the released FormFile object, otherwise data races may occur. +func ReleaseFormFile(ff *FormFile) { + ff.Fieldname = "" + ff.Name = "" + ff.Content = ff.Content[:0] + ff.autoRelease = false + + formFilePool.Put(ff) +} + +var ( + strHTTP = []byte("http") + strHTTPS = []byte("https") + defaultUserAgent = "fiber" +) + +type multipartWriter interface { + Boundary() string + SetBoundary(boundary string) error + CreateFormFile(fieldname, filename string) (io.Writer, error) + WriteField(fieldname, value string) error + Close() error +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 00000000..10aa97e3 --- /dev/null +++ b/client_test.go @@ -0,0 +1,1086 @@ +package fiber + +import ( + "bytes" + "crypto/tls" + "encoding/base64" + "errors" + "io" + "io/ioutil" + "mime/multipart" + "net" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + "github.com/gofiber/fiber/v2/internal/encoding/json" + + "github.com/gofiber/fiber/v2/utils" + "github.com/valyala/fasthttp/fasthttputil" +) + +func Test_Client_Invalid_URL(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString(c.Hostname()) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + a := Get("http://example.com\r\n\r\nGET /\r\n\r\n") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + _, body, errs := a.String() + + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, "missing required Host header in request", errs[0].Error()) +} + +func Test_Client_Unsupported_Protocol(t *testing.T) { + t.Parallel() + + a := Get("ftp://example.com") + + _, body, errs := a.String() + + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, `unsupported protocol "ftp". http and https are supported`, + errs[0].Error()) +} + +func Test_Client_Get(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString(c.Hostname()) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + a := Get("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "example.com", body) + utils.AssertEqual(t, 0, len(errs)) + } +} + +func Test_Client_Head(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString(c.Hostname()) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + a := Head("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 0, len(errs)) + } +} + +func Test_Client_Post(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Post("/", func(c *Ctx) error { + return c.Status(StatusCreated). + SendString(c.FormValue("foo")) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + args := AcquireArgs() + + args.Set("foo", "bar") + + a := Post("http://example.com"). + Form(args) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusCreated, code) + utils.AssertEqual(t, "bar", body) + utils.AssertEqual(t, 0, len(errs)) + + ReleaseArgs(args) + } +} + +func Test_Client_Put(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Put("/", func(c *Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + args := AcquireArgs() + + args.Set("foo", "bar") + + a := Put("http://example.com"). + Form(args) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "bar", body) + utils.AssertEqual(t, 0, len(errs)) + + ReleaseArgs(args) + } +} + +func Test_Client_Patch(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Patch("/", func(c *Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + args := AcquireArgs() + + args.Set("foo", "bar") + + a := Patch("http://example.com"). + Form(args) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "bar", body) + utils.AssertEqual(t, 0, len(errs)) + + ReleaseArgs(args) + } +} + +func Test_Client_Delete(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Delete("/", func(c *Ctx) error { + return c.Status(StatusNoContent). + SendString("deleted") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + args := AcquireArgs() + + a := Delete("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusNoContent, code) + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 0, len(errs)) + + ReleaseArgs(args) + } +} + +func Test_Client_UserAgent(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + t.Run("default", func(t *testing.T) { + for i := 0; i < 5; i++ { + a := Get("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, defaultUserAgent, body) + utils.AssertEqual(t, 0, len(errs)) + } + }) + + t.Run("custom", func(t *testing.T) { + for i := 0; i < 5; i++ { + c := AcquireClient() + c.UserAgent = "ua" + + a := c.Get("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "ua", body) + utils.AssertEqual(t, 0, len(errs)) + ReleaseClient(c) + } + }) +} + +func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { + handler := func(c *Ctx) error { + c.Request().Header.VisitAll(func(key, value []byte) { + if k := string(key); k == "K1" || k == "K2" { + _, _ = c.Write(key) + _, _ = c.Write(value) + } + }) + return nil + } + + wrapAgent := func(a *Agent) { + a.Set("k1", "v1"). + SetBytesK([]byte("k1"), "v1"). + SetBytesV("k1", []byte("v1")). + AddBytesK([]byte("k1"), "v11"). + AddBytesV("k1", []byte("v22")). + AddBytesKV([]byte("k1"), []byte("v33")). + SetBytesKV([]byte("k2"), []byte("v2")). + Add("k2", "v22") + + } + + testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +} + +func Test_Client_Agent_Connection_Close(t *testing.T) { + handler := func(c *Ctx) error { + if c.Request().Header.ConnectionClose() { + return c.SendString("close") + } + return c.SendString("not close") + } + + wrapAgent := func(a *Agent) { + a.ConnectionClose() + } + + testAgent(t, handler, wrapAgent, "close") +} + +func Test_Client_Agent_UserAgent(t *testing.T) { + handler := func(c *Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + } + + wrapAgent := func(a *Agent) { + a.UserAgent("ua"). + UserAgentBytes([]byte("ua")) + } + + testAgent(t, handler, wrapAgent, "ua") +} + +func Test_Client_Agent_Cookie(t *testing.T) { + handler := func(c *Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) + } + + wrapAgent := func(a *Agent) { + a.Cookie("k1", "v1"). + CookieBytesK([]byte("k2"), "v2"). + CookieBytesKV([]byte("k2"), []byte("v2")). + Cookies("k3", "v3", "k4", "v4"). + CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) + } + + testAgent(t, handler, wrapAgent, "v1v2v3v4") +} + +func Test_Client_Agent_Referer(t *testing.T) { + handler := func(c *Ctx) error { + return c.Send(c.Request().Header.Referer()) + } + + wrapAgent := func(a *Agent) { + a.Referer("http://referer.com"). + RefererBytes([]byte("http://referer.com")) + } + + testAgent(t, handler, wrapAgent, "http://referer.com") +} + +func Test_Client_Agent_ContentType(t *testing.T) { + handler := func(c *Ctx) error { + return c.Send(c.Request().Header.ContentType()) + } + + wrapAgent := func(a *Agent) { + a.ContentType("custom-type"). + ContentTypeBytes([]byte("custom-type")) + } + + testAgent(t, handler, wrapAgent, "custom-type") +} + +func Test_Client_Agent_Host(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString(c.Hostname()) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + a := Get("http://1.1.1.1:8080"). + Host("example.com"). + HostBytes([]byte("example.com")) + + utils.AssertEqual(t, "1.1.1.1:8080", a.HostClient.Addr) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "example.com", body) + utils.AssertEqual(t, 0, len(errs)) +} + +func Test_Client_Agent_QueryString(t *testing.T) { + handler := func(c *Ctx) error { + return c.Send(c.Request().URI().QueryString()) + } + + wrapAgent := func(a *Agent) { + a.QueryString("foo=bar&bar=baz"). + QueryStringBytes([]byte("foo=bar&bar=baz")) + } + + testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +} + +func Test_Client_Agent_BasicAuth(t *testing.T) { + handler := func(c *Ctx) error { + // Get authorization header + auth := c.Get(HeaderAuthorization) + // Decode the header contents + raw, err := base64.StdEncoding.DecodeString(auth[6:]) + utils.AssertEqual(t, nil, err) + + return c.Send(raw) + } + + wrapAgent := func(a *Agent) { + a.BasicAuth("foo", "bar"). + BasicAuthBytes([]byte("foo"), []byte("bar")) + } + + testAgent(t, handler, wrapAgent, "foo:bar") +} + +func Test_Client_Agent_BodyString(t *testing.T) { + handler := func(c *Ctx) error { + return c.Send(c.Request().Body()) + } + + wrapAgent := func(a *Agent) { + a.BodyString("foo=bar&bar=baz") + } + + testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +} + +func Test_Client_Agent_Body(t *testing.T) { + handler := func(c *Ctx) error { + return c.Send(c.Request().Body()) + } + + wrapAgent := func(a *Agent) { + a.Body([]byte("foo=bar&bar=baz")) + } + + testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +} + +func Test_Client_Agent_BodyStream(t *testing.T) { + handler := func(c *Ctx) error { + return c.Send(c.Request().Body()) + } + + wrapAgent := func(a *Agent) { + a.BodyStream(strings.NewReader("body stream"), -1) + } + + testAgent(t, handler, wrapAgent, "body stream") +} + +func Test_Client_Agent_Custom_Response(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString("custom") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + a := AcquireAgent() + resp := AcquireResponse() + + req := a.Request() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://example.com") + + utils.AssertEqual(t, nil, a.Parse()) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.SetResponse(resp). + String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "custom", body) + utils.AssertEqual(t, "custom", string(resp.Body())) + utils.AssertEqual(t, 0, len(errs)) + + ReleaseResponse(resp) + } +} + +func Test_Client_Agent_Dest(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString("dest") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + t.Run("small dest", func(t *testing.T) { + dest := []byte("de") + + a := Get("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.Dest(dest[:0]).String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "dest", body) + utils.AssertEqual(t, "de", string(dest)) + utils.AssertEqual(t, 0, len(errs)) + }) + + t.Run("enough dest", func(t *testing.T) { + dest := []byte("foobar") + + a := Get("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.Dest(dest[:0]).String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "dest", body) + utils.AssertEqual(t, "destar", string(dest)) + utils.AssertEqual(t, 0, len(errs)) + }) +} + +func Test_Client_Agent_Json(t *testing.T) { + handler := func(c *Ctx) error { + utils.AssertEqual(t, MIMEApplicationJSON, string(c.Request().Header.ContentType())) + + return c.Send(c.Request().Body()) + } + + wrapAgent := func(a *Agent) { + a.JSON(data{Success: true}) + } + + testAgent(t, handler, wrapAgent, `{"success":true}`) +} + +func Test_Client_Agent_Json_Error(t *testing.T) { + a := Get("http://example.com"). + JSONEncoder(json.Marshal). + JSON(complex(1, 1)) + + _, body, errs := a.String() + + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, "json: unsupported type: complex128", errs[0].Error()) +} + +func Test_Client_Agent_XML(t *testing.T) { + handler := func(c *Ctx) error { + utils.AssertEqual(t, MIMEApplicationXML, string(c.Request().Header.ContentType())) + + return c.Send(c.Request().Body()) + } + + wrapAgent := func(a *Agent) { + a.XML(data{Success: true}) + } + + testAgent(t, handler, wrapAgent, "true") +} + +func Test_Client_Agent_XML_Error(t *testing.T) { + a := Get("http://example.com"). + XML(complex(1, 1)) + + _, body, errs := a.String() + + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, "xml: unsupported type: complex128", errs[0].Error()) +} + +func Test_Client_Agent_Form(t *testing.T) { + handler := func(c *Ctx) error { + utils.AssertEqual(t, MIMEApplicationForm, string(c.Request().Header.ContentType())) + + return c.Send(c.Request().Body()) + } + + args := AcquireArgs() + + args.Set("foo", "bar") + + wrapAgent := func(a *Agent) { + a.Form(args) + } + + testAgent(t, handler, wrapAgent, "foo=bar") + + ReleaseArgs(args) +} + +func Test_Client_Agent_MultipartForm(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Post("/", func(c *Ctx) error { + utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) + + mf, err := c.MultipartForm() + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", mf.Value["foo"][0]) + + return c.Send(c.Request().Body()) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + args := AcquireArgs() + + args.Set("foo", "bar") + + a := Post("http://example.com"). + Boundary("myBoundary"). + MultipartForm(args) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) + utils.AssertEqual(t, 0, len(errs)) + ReleaseArgs(args) +} + +func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { + t.Parallel() + + a := AcquireAgent() + a.mw = &errorMultipartWriter{} + + args := AcquireArgs() + args.Set("foo", "bar") + + ff1 := &FormFile{"", "name1", []byte("content"), false} + ff2 := &FormFile{"", "name2", []byte("content"), false} + a.FileData(ff1, ff2). + MultipartForm(args) + + utils.AssertEqual(t, 4, len(a.errs)) + ReleaseArgs(args) +} + +func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Post("/", func(c *Ctx) error { + utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) + + fh1, err := c.FormFile("field1") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fh1.Filename, "name") + buf := make([]byte, fh1.Size, fh1.Size) + f, err := fh1.Open() + utils.AssertEqual(t, nil, err) + defer func() { _ = f.Close() }() + _, err = f.Read(buf) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "form file", string(buf)) + + fh2, err := c.FormFile("index") + utils.AssertEqual(t, nil, err) + checkFormFile(t, fh2, ".github/testdata/index.html") + + fh3, err := c.FormFile("file3") + utils.AssertEqual(t, nil, err) + checkFormFile(t, fh3, ".github/testdata/index.tmpl") + + return c.SendString("multipart form files") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + for i := 0; i < 5; i++ { + ff := AcquireFormFile() + ff.Fieldname = "field1" + ff.Name = "name" + ff.Content = []byte("form file") + + a := Post("http://example.com"). + Boundary("myBoundary"). + FileData(ff). + SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). + MultipartForm(nil) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "multipart form files", body) + utils.AssertEqual(t, 0, len(errs)) + + ReleaseFormFile(ff) + } +} + +func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { + basename := filepath.Base(filename) + utils.AssertEqual(t, fh.Filename, basename) + + b1, err := ioutil.ReadFile(filename) + utils.AssertEqual(t, nil, err) + + b2 := make([]byte, fh.Size, fh.Size) + f, err := fh.Open() + utils.AssertEqual(t, nil, err) + defer func() { _ = f.Close() }() + _, err = f.Read(b2) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, b1, b2) +} + +func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { + t.Parallel() + + a := Post("http://example.com"). + MultipartForm(nil) + + reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) + + utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(HeaderContentType))) +} + +func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { + t.Parallel() + + a := Post("http://example.com"). + Boundary("*"). + MultipartForm(nil) + + utils.AssertEqual(t, 1, len(a.errs)) + utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) +} + +func Test_Client_Agent_SendFile_Error(t *testing.T) { + t.Parallel() + + a := Post("http://example.com"). + SendFile("non-exist-file!", "") + + utils.AssertEqual(t, 1, len(a.errs)) + utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) +} + +func Test_Client_Debug(t *testing.T) { + handler := func(c *Ctx) error { + return c.SendString("debug") + } + + var output bytes.Buffer + + wrapAgent := func(a *Agent) { + a.Debug(&output) + } + + testAgent(t, handler, wrapAgent, "debug", 1) + + str := output.String() + + utils.AssertEqual(t, true, strings.Contains(str, "Connected to example.com(pipe)")) + utils.AssertEqual(t, true, strings.Contains(str, "GET / HTTP/1.1")) + utils.AssertEqual(t, true, strings.Contains(str, "User-Agent: fiber")) + utils.AssertEqual(t, true, strings.Contains(str, "Host: example.com\r\n\r\n")) + utils.AssertEqual(t, true, strings.Contains(str, "HTTP/1.1 200 OK")) + utils.AssertEqual(t, true, strings.Contains(str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")) +} + +func Test_Client_Agent_Timeout(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + time.Sleep(time.Millisecond * 200) + return c.SendString("timeout") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + a := Get("http://example.com"). + Timeout(time.Millisecond * 100) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + _, body, errs := a.String() + + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, "timeout", errs[0].Error()) +} + +func Test_Client_Agent_Reuse(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString("reuse") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + a := Get("http://example.com"). + Reuse() + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "reuse", body) + utils.AssertEqual(t, 0, len(errs)) + + code, body, errs = a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "reuse", body) + utils.AssertEqual(t, 0, len(errs)) +} + +func Test_Client_Agent_TLS(t *testing.T) { + t.Parallel() + + // Create tls certificate + cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") + utils.AssertEqual(t, nil, err) + + config := &tls.Config{ + Certificates: []tls.Certificate{cer}, + } + + ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") + utils.AssertEqual(t, nil, err) + + ln = tls.NewListener(ln, config) + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.SendString("tls") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + code, body, errs := Get("https://" + ln.Addr().String()). + InsecureSkipVerify(). + TLSConfig(config). + InsecureSkipVerify(). + String() + + utils.AssertEqual(t, 0, len(errs)) + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, "tls", body) +} + +func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + if c.Request().URI().QueryArgs().Has("foo") { + return c.Redirect("/foo") + } + return c.Redirect("/") + }) + app.Get("/foo", func(c *Ctx) error { + return c.SendString("redirect") + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + t.Run("success", func(t *testing.T) { + a := Get("http://example.com?foo"). + MaxRedirectsCount(1) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, 200, code) + utils.AssertEqual(t, "redirect", body) + utils.AssertEqual(t, 0, len(errs)) + }) + + t.Run("error", func(t *testing.T) { + a := Get("http://example.com"). + MaxRedirectsCount(1) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + _, body, errs := a.String() + + utils.AssertEqual(t, "", body) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, "too many redirects detected when doing the request", errs[0].Error()) + }) +} + +func Test_Client_Agent_Struct(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", func(c *Ctx) error { + return c.JSON(data{true}) + }) + + app.Get("/error", func(c *Ctx) error { + return c.SendString(`{"success"`) + }) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + t.Run("success", func(t *testing.T) { + a := Get("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + var d data + + code, body, errs := a.Struct(&d) + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, `{"success":true}`, string(body)) + utils.AssertEqual(t, 0, len(errs)) + utils.AssertEqual(t, true, d.Success) + }) + + t.Run("pre error", func(t *testing.T) { + a := Get("http://example.com") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + a.errs = append(a.errs, errors.New("pre errors")) + + var d data + _, body, errs := a.Struct(&d) + + utils.AssertEqual(t, "", string(body)) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, "pre errors", errs[0].Error()) + utils.AssertEqual(t, false, d.Success) + }) + + t.Run("error", func(t *testing.T) { + a := Get("http://example.com/error") + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + var d data + + code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, `{"success"`, string(body)) + utils.AssertEqual(t, 1, len(errs)) + utils.AssertEqual(t, "json: unexpected end of JSON input after object field key: ", errs[0].Error()) + }) +} + +func Test_Client_Agent_Parse(t *testing.T) { + t.Parallel() + + a := Get("https://example.com:10443") + + utils.AssertEqual(t, nil, a.Parse()) +} + +func Test_AddMissingPort_TLS(t *testing.T) { + addr := addMissingPort("example.com", true) + utils.AssertEqual(t, "example.com:443", addr) +} + +func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + app.Get("/", handler) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + for i := 0; i < c; i++ { + a := Get("http://example.com") + + wrapAgent(a) + + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + code, body, errs := a.String() + + utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, excepted, body) + utils.AssertEqual(t, 0, len(errs)) + } +} + +type data struct { + Success bool `json:"success" xml:"success"` +} + +type errorMultipartWriter struct { + count int +} + +func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } +func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } +func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { + if e.count == 0 { + e.count++ + return nil, errors.New("CreateFormFile error") + } + return errorWriter{}, nil +} +func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } +func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } + +type errorWriter struct{} + +func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } diff --git a/utils/json.go b/utils/json.go new file mode 100644 index 00000000..477c8c33 --- /dev/null +++ b/utils/json.go @@ -0,0 +1,9 @@ +package utils + +// JSONMarshal returns the JSON encoding of v. +type JSONMarshal func(v interface{}) ([]byte, error) + +// JSONUnmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an InvalidUnmarshalError. +type JSONUnmarshal func(data []byte, v interface{}) error diff --git a/utils/json_marshal.go b/utils/json_marshal.go deleted file mode 100644 index 692b49a4..00000000 --- a/utils/json_marshal.go +++ /dev/null @@ -1,5 +0,0 @@ -package utils - -// JSONMarshal is the standard definition of representing a Go structure in -// json format -type JSONMarshal func(interface{}) ([]byte, error) diff --git a/utils/json_marshal_test.go b/utils/json_marshal_test.go deleted file mode 100644 index 08501d96..00000000 --- a/utils/json_marshal_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package utils - -import ( - "encoding/json" - "testing" -) - -func TestDefaultJSONEncoder(t *testing.T) { - type SampleStructure struct { - ImportantString string `json:"important_string"` - } - - var ( - sampleStructure = &SampleStructure{ - ImportantString: "Hello World", - } - importantString = `{"important_string":"Hello World"}` - - jsonEncoder JSONMarshal = json.Marshal - ) - - raw, err := jsonEncoder(sampleStructure) - AssertEqual(t, err, nil) - - AssertEqual(t, string(raw), importantString) -} diff --git a/utils/json_test.go b/utils/json_test.go new file mode 100644 index 00000000..966faa83 --- /dev/null +++ b/utils/json_test.go @@ -0,0 +1,41 @@ +package utils + +import ( + "encoding/json" + "testing" +) + +type sampleStructure struct { + ImportantString string `json:"important_string"` +} + +func Test_DefaultJSONEncoder(t *testing.T) { + t.Parallel() + + var ( + ss = &sampleStructure{ + ImportantString: "Hello World", + } + importantString = `{"important_string":"Hello World"}` + jsonEncoder JSONMarshal = json.Marshal + ) + + raw, err := jsonEncoder(ss) + AssertEqual(t, err, nil) + + AssertEqual(t, string(raw), importantString) +} + +func Test_DefaultJSONDecoder(t *testing.T) { + t.Parallel() + + var ( + ss sampleStructure + importantString = []byte(`{"important_string":"Hello World"}`) + jsonDecoder JSONUnmarshal = json.Unmarshal + ) + + err := jsonDecoder(importantString, &ss) + AssertEqual(t, err, nil) + AssertEqual(t, "Hello World", ss.ImportantString) +}