v3 (feature): client refactor (#1986)

*  v3: Move the client module to the client folder and fix the error

*  v3: add xml encoder and decoder

* 🚧 v3: design plugin and hook mechanism, complete simple get request

* 🚧 v3: reset add some field

* 🚧 v3: add doc and fix some error

* 🚧 v3: add header merge

* 🚧 v3: add query param

* 🚧 v3: change to fasthttp's header and args

*  v3: add body and ua setting

* 🚧 v3: add cookie support

* 🚧 v3: add path param support

*  v3: fix error test case

* 🚧 v3: add formdata and file support

* 🚧 v3: referer support

* 🚧 v3: reponse unmarshal

*  v3: finish API design

* 🔥 v3: remove plugin mechanism

* 🚧 v3: add timeout

* 🚧 v3: change path params pattern and add unit test for core

* ✏️ v3: error spell

*  v3: improve test coverage

*  perf: change test func name to fit project format

* 🚧 v3: handle error

* 🚧 v3: add unit test and fix error

* ️ chore: change func to improve performance

*  v3: add some unit test

*  v3: fix error test

* 🐛 fix: add cookie to response

*  v3: add unit test

*  v3: export raw field

* 🐛 fix: fix data race

* 🔒️ chore: change package

* 🐛 fix: data race

* 🐛 fix: test fail

*  feat: move core to req

* 🐛 fix: connection reuse

* 🐛 fix: data race

* 🐛 fix: data race

* 🔀 fix: change to testify

*  fix: fail test in windows

*  feat: response body save to file

*  feat: support tls config

* 🐛 fix: add err check

* 🎨 perf: fix some static check

*  feat: add proxy support

*  feat: add retry feature

* 🐛 fix: static check error

* 🎨 refactor: move som code

* docs: change readme

*  feat: extend axios API

* perf: change field to export field

*  chore: disable startup message

* 🐛 fix: fix test error

* chore: fix error test

* chore: fix test case

* feat: add some test to client

* chore: add test case

* chore: add test case

*  feat: add peek for client

*  chore: add test case

* ️ feat: lazy generate rand string

* 🚧 perf: add config test case

* 🐛 fix: fix merge error

* 🐛 fix utils error

*  add redirection

* 🔥 chore: delete deps

* perf: fix spell error

* 🎨 perf: spell error

*  feat: add logger

*  feat: add cookie jar

*  feat: logger with level

* 🎨 perf: change the field name

* perf: add jar test

* fix proxy test

* improve test coverage

* fix proxy tests

* add cookiejar support from pending fasthttp PR

* fix some lint errors.

* add benchmark for SetValWithStruct

* optimize

* update

* fix proxy middleware

* use panicf instead of errorf and fix panic on default logger

* update

* update

* cleanup comments

* cleanup comments

* fix golang-lint errors

* Update helper_test.go

* add more test cases

* add hostclient pool

* make it more thread safe
-> there is still something which is shared between the requests

* fixed some golangci-lint errors

* fix Test_Request_FormData test

* create new test suite

* just create client for once

* use random port instead of 3000

* remove client pooling and fix test suite

* fix data races on logger tests

* fix proxy tests

* fix global tests

* remove unused code

* fix logger test

* fix proxy tests

* fix linter

* use lock instead of rlock

* fix cookiejar data-race

* fix(client): race conditions

* fix(client): race conditions

* apply some reviews

* change client property name

* apply review

* add parallel benchmark for simple request

* apply review

* apply review

* fix log tests

* fix linter

* fix(client): return error in SetProxyURL instead of panic

---------

Co-authored-by: Muhammed Efe Çetin <efectn@protonmail.com>
Co-authored-by: René Werner <rene.werner@verivox.com>
Co-authored-by: Joey <fenny@gofiber.io>
Co-authored-by: René <rene@gofiber.io>
pull/2896/head
Jinquan Wang 2024-03-04 15:49:14 +08:00 committed by GitHub
parent 67d35dc068
commit b38be4bcb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 7884 additions and 2431 deletions

1021
client.go

File diff suppressed because it is too large Load Diff

35
client/README.md Normal file
View File

@ -0,0 +1,35 @@
<h1 align="center">Fiber Client</h1>
<p align="center">Easy-to-use HTTP client based on fasthttp (inspired by <a href="https://github.com/go-resty/resty">resty</a> and <a href="https://github.com/axios/axios">axios</a>)</p>
<p align="center"><a href="#features">Features</a> section describes in detail about Resty capabilities</p>
## Features
> The characteristics have not yet been written.
- GET, POST, PUT, DELETE, HEAD, PATCH, OPTIONS, etc.
- Simple and chainable methods for settings and request
- Request Body can be `string`, `[]byte`, `map`, `slice`
- Auto detects `Content-Type`
- Buffer processing for `files`
- Native `*fasthttp.Request` instance can be accessed during middleware and request execution via `Request.RawRequest`
- Request Body can be read multiple time via `Request.RawRequest.GetBody()`
- Response object gives you more possibility
- Access as `[]byte` by `response.Body()` or access as `string` by `response.String()`
- Automatic marshal and unmarshal for JSON and XML content type
- Default is JSON, if you supply struct/map without header Content-Type
- For auto-unmarshal, refer to -
- Success scenario Request.SetResult() and Response.Result().
- Error scenario Request.SetError() and Response.Error().
- Supports RFC7807 - application/problem+json & application/problem+xml
- Provide an option to override JSON Marshal/Unmarshal and XML Marshal/Unmarshal
## Usage
The following samples will assist you to become as comfortable as possible with `Fiber Client` library.
```go
// Import Fiber Client into your code and refer it as `client`.
import "github.com/gofiber/fiber/client"
```
### Simple GET

775
client/client.go Normal file
View File

@ -0,0 +1,775 @@
package client
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
urlpkg "net/url"
"os"
"path/filepath"
"sync"
"time"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
var (
ErrInvalidProxyURL = errors.New("invalid proxy url scheme")
ErrFailedToAppendCert = errors.New("failed to append certificate")
)
// The Client is used to create a Fiber Client with
// client-level settings that apply to all requests
// raise from the client.
//
// Fiber Client also provides an option to override
// or merge most of the client settings at the request.
type Client struct {
mu sync.RWMutex
fasthttp *fasthttp.Client
baseURL string
userAgent string
referer string
header *Header
params *QueryParam
cookies *Cookie
path *PathParam
debug bool
timeout time.Duration
// user defined request hooks
userRequestHooks []RequestHook
// client package defined request hooks
builtinRequestHooks []RequestHook
// user defined response hooks
userResponseHooks []ResponseHook
// client package defined response hooks
builtinResponseHooks []ResponseHook
jsonMarshal utils.JSONMarshal
jsonUnmarshal utils.JSONUnmarshal
xmlMarshal utils.XMLMarshal
xmlUnmarshal utils.XMLUnmarshal
cookieJar *CookieJar
// proxy
proxyURL string
// retry
retryConfig *RetryConfig
// logger
logger log.CommonLogger
}
// R raise a request from the client.
func (c *Client) R() *Request {
return AcquireRequest().SetClient(c)
}
// RequestHook Request returns user-defined request hooks.
func (c *Client) RequestHook() []RequestHook {
return c.userRequestHooks
}
// AddRequestHook Add user-defined request hooks.
func (c *Client) AddRequestHook(h ...RequestHook) *Client {
c.mu.Lock()
defer c.mu.Unlock()
c.userRequestHooks = append(c.userRequestHooks, h...)
return c
}
// ResponseHook return user-define response hooks.
func (c *Client) ResponseHook() []ResponseHook {
return c.userResponseHooks
}
// AddResponseHook Add user-defined response hooks.
func (c *Client) AddResponseHook(h ...ResponseHook) *Client {
c.mu.Lock()
defer c.mu.Unlock()
c.userResponseHooks = append(c.userResponseHooks, h...)
return c
}
// JSONMarshal returns json marshal function in Core.
func (c *Client) JSONMarshal() utils.JSONMarshal {
return c.jsonMarshal
}
// SetJSONMarshal Set json encoder.
func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client {
c.jsonMarshal = f
return c
}
// JSONUnmarshal returns json unmarshal function in Core.
func (c *Client) JSONUnmarshal() utils.JSONUnmarshal {
return c.jsonUnmarshal
}
// Set json decoder.
func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client {
c.jsonUnmarshal = f
return c
}
// XMLMarshal returns xml marshal function in Core.
func (c *Client) XMLMarshal() utils.XMLMarshal {
return c.xmlMarshal
}
// SetXMLMarshal Set xml encoder.
func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client {
c.xmlMarshal = f
return c
}
// XMLUnmarshal returns xml unmarshal function in Core.
func (c *Client) XMLUnmarshal() utils.XMLUnmarshal {
return c.xmlUnmarshal
}
// SetXMLUnmarshal Set xml decoder.
func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client {
c.xmlUnmarshal = f
return c
}
// TLSConfig returns tlsConfig in client.
// If client don't have tlsConfig, this function will init it.
func (c *Client) TLSConfig() *tls.Config {
if c.fasthttp.TLSConfig == nil {
c.fasthttp.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
return c.fasthttp.TLSConfig
}
// SetTLSConfig sets tlsConfig in client.
func (c *Client) SetTLSConfig(config *tls.Config) *Client {
c.fasthttp.TLSConfig = config
return c
}
// SetCertificates method sets client certificates into client.
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
config := c.TLSConfig()
config.Certificates = append(config.Certificates, certs...)
return c
}
// SetRootCertificate adds one or more root certificates into client.
func (c *Client) SetRootCertificate(path string) *Client {
cleanPath := filepath.Clean(path)
file, err := os.Open(cleanPath)
if err != nil {
c.logger.Panicf("client: %v", err)
}
defer func() {
if err := file.Close(); err != nil {
c.logger.Panicf("client: failed to close file: %v", err)
}
}()
pem, err := io.ReadAll(file)
if err != nil {
c.logger.Panicf("client: %v", err)
}
config := c.TLSConfig()
if config.RootCAs == nil {
config.RootCAs = x509.NewCertPool()
}
if !config.RootCAs.AppendCertsFromPEM(pem) {
c.logger.Panicf("client: %v", ErrFailedToAppendCert)
}
return c
}
// SetRootCertificateFromString method adds one or more root certificates into client.
func (c *Client) SetRootCertificateFromString(pem string) *Client {
config := c.TLSConfig()
if config.RootCAs == nil {
config.RootCAs = x509.NewCertPool()
}
if !config.RootCAs.AppendCertsFromPEM([]byte(pem)) {
c.logger.Panicf("client: %v", ErrFailedToAppendCert)
}
return c
}
// SetProxyURL sets proxy url in client. It will apply via core to hostclient.
func (c *Client) SetProxyURL(proxyURL string) error {
pURL, err := urlpkg.Parse(proxyURL)
if err != nil {
return fmt.Errorf("client: %w", err)
}
if pURL.Scheme != "http" && pURL.Scheme != "https" {
return fmt.Errorf("client: %w", ErrInvalidProxyURL)
}
c.proxyURL = pURL.String()
return nil
}
// RetryConfig returns retry config in client.
func (c *Client) RetryConfig() *RetryConfig {
return c.retryConfig
}
// SetRetryConfig sets retry config in client which is impl by addon/retry package.
func (c *Client) SetRetryConfig(config *RetryConfig) *Client {
c.mu.Lock()
defer c.mu.Unlock()
c.retryConfig = config
return c
}
// BaseURL returns baseurl in Client instance.
func (c *Client) BaseURL() string {
return c.baseURL
}
// SetBaseURL Set baseUrl which is prefix of real url.
func (c *Client) SetBaseURL(url string) *Client {
c.baseURL = url
return c
}
// Header method returns header value via key,
// this method will visit all field in the header,
// then sort them.
func (c *Client) Header(key string) []string {
return c.header.PeekMultiple(key)
}
// AddHeader method adds a single header field and its value in the client instance.
// These headers will be applied to all requests raised from this client instance.
// Also, it can be overridden at request level header options.
func (c *Client) AddHeader(key, val string) *Client {
c.header.Add(key, val)
return c
}
// SetHeader method sets a single header field and its value in the client instance.
// These headers will be applied to all requests raised from this client instance.
// Also, it can be overridden at request level header options.
func (c *Client) SetHeader(key, val string) *Client {
c.header.Set(key, val)
return c
}
// AddHeaders method adds multiple headers field and its values at one go in the client instance.
// These headers will be applied to all requests raised from this client instance. Also it can be
// overridden at request level headers options.
func (c *Client) AddHeaders(h map[string][]string) *Client {
c.header.AddHeaders(h)
return c
}
// SetHeaders method sets multiple headers field and its values at one go in the client instance.
// These headers will be applied to all requests raised from this client instance. Also it can be
// overridden at request level headers options.
func (c *Client) SetHeaders(h map[string]string) *Client {
c.header.SetHeaders(h)
return c
}
// Param method returns params value via key,
// this method will visit all field in the query param.
func (c *Client) Param(key string) []string {
res := []string{}
tmp := c.params.PeekMulti(key)
for _, v := range tmp {
res = append(res, utils.UnsafeString(v))
}
return res
}
// AddParam method adds a single query param field and its value in the client instance.
// These params will be applied to all requests raised from this client instance.
// Also, it can be overridden at request level param options.
func (c *Client) AddParam(key, val string) *Client {
c.params.Add(key, val)
return c
}
// SetParam method sets a single query param field and its value in the client instance.
// These params will be applied to all requests raised from this client instance.
// Also, it can be overridden at request level param options.
func (c *Client) SetParam(key, val string) *Client {
c.params.Set(key, val)
return c
}
// AddParams method adds multiple query params field and its values at one go in the client instance.
// These params will be applied to all requests raised from this client instance. Also it can be
// overridden at request level params options.
func (c *Client) AddParams(m map[string][]string) *Client {
c.params.AddParams(m)
return c
}
// SetParams method sets multiple params field and its values at one go in the client instance.
// These params will be applied to all requests raised from this client instance. Also it can be
// overridden at request level params options.
func (c *Client) SetParams(m map[string]string) *Client {
c.params.SetParams(m)
return c
}
// SetParamsWithStruct method sets multiple params field and its values at one go in the client instance.
// These params will be applied to all requests raised from this client instance. Also it can be
// overridden at request level params options.
func (c *Client) SetParamsWithStruct(v any) *Client {
c.params.SetParamsWithStruct(v)
return c
}
// DelParams method deletes single or multiple params field and its values in client.
func (c *Client) DelParams(key ...string) *Client {
for _, v := range key {
c.params.Del(v)
}
return c
}
// SetUserAgent method sets userAgent field and its value in the client instance.
// This ua will be applied to all requests raised from this client instance.
// Also it can be overridden at request level ua options.
func (c *Client) SetUserAgent(ua string) *Client {
c.userAgent = ua
return c
}
// SetReferer method sets referer field and its value in the client instance.
// This referer will be applied to all requests raised from this client instance.
// Also it can be overridden at request level referer options.
func (c *Client) SetReferer(r string) *Client {
c.referer = r
return c
}
// PathParam returns the path param be set in request instance.
// if path param doesn't exist, return empty string.
func (c *Client) PathParam(key string) string {
if val, ok := (*c.path)[key]; ok {
return val
}
return ""
}
// SetPathParam method sets a single path param field and its value in the client instance.
// These path params will be applied to all requests raised from this client instance.
// Also it can be overridden at request level path params options.
func (c *Client) SetPathParam(key, val string) *Client {
c.path.SetParam(key, val)
return c
}
// SetPathParams method sets multiple path params field and its values at one go in the client instance.
// These path params will be applied to all requests raised from this client instance. Also it can be
// overridden at request level path params options.
func (c *Client) SetPathParams(m map[string]string) *Client {
c.path.SetParams(m)
return c
}
// SetPathParamsWithStruct method sets multiple path params field and its values at one go in the client instance.
// These path params will be applied to all requests raised from this client instance. Also it can be
// overridden at request level path params options.
func (c *Client) SetPathParamsWithStruct(v any) *Client {
c.path.SetParamsWithStruct(v)
return c
}
// DelPathParams method deletes single or multiple path params field and its values in client.
func (c *Client) DelPathParams(key ...string) *Client {
c.path.DelParams(key...)
return c
}
// Cookie returns the cookie be set in request instance.
// if cookie doesn't exist, return empty string.
func (c *Client) Cookie(key string) string {
if val, ok := (*c.cookies)[key]; ok {
return val
}
return ""
}
// SetCookie method sets a single cookie field and its value in the client instance.
// These cookies will be applied to all requests raised from this client instance.
// Also it can be overridden at request level cookie options.
func (c *Client) SetCookie(key, val string) *Client {
c.cookies.SetCookie(key, val)
return c
}
// SetCookies method sets multiple cookies field and its values at one go in the client instance.
// These cookies will be applied to all requests raised from this client instance. Also it can be
// overridden at request level cookie options.
func (c *Client) SetCookies(m map[string]string) *Client {
c.cookies.SetCookies(m)
return c
}
// SetCookiesWithStruct method sets multiple cookies field and its values at one go in the client instance.
// These cookies will be applied to all requests raised from this client instance. Also it can be
// overridden at request level cookies options.
func (c *Client) SetCookiesWithStruct(v any) *Client {
c.cookies.SetCookiesWithStruct(v)
return c
}
// DelCookies method deletes single or multiple cookies field and its values in client.
func (c *Client) DelCookies(key ...string) *Client {
c.cookies.DelCookies(key...)
return c
}
// SetTimeout method sets timeout val in client instance.
// This value will be applied to all requests raised from this client instance.
// Also, it can be overridden at request level timeout options.
func (c *Client) SetTimeout(t time.Duration) *Client {
c.timeout = t
return c
}
// Debug enable log debug level output.
func (c *Client) Debug() *Client {
c.debug = true
return c
}
// DisableDebug disenable log debug level output.
func (c *Client) DisableDebug() *Client {
c.debug = false
return c
}
// SetCookieJar sets cookie jar in client instance.
func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client {
c.cookieJar = cookieJar
return c
}
// Get provide an API like axios which send get request.
func (c *Client) Get(url string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Get(url)
}
// Post provide an API like axios which send post request.
func (c *Client) Post(url string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Post(url)
}
// Head provide a API like axios which send head request.
func (c *Client) Head(url string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Head(url)
}
// Put provide an API like axios which send put request.
func (c *Client) Put(url string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Put(url)
}
// Delete provide an API like axios which send delete request.
func (c *Client) Delete(url string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Delete(url)
}
// Options provide an API like axios which send options request.
func (c *Client) Options(url string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Options(url)
}
// Patch provide an API like axios which send patch request.
func (c *Client) Patch(url string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Patch(url)
}
// Custom provide an API like axios which send custom request.
func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error) {
req := AcquireRequest().SetClient(c)
setConfigToRequest(req, cfg...)
return req.Custom(url, method)
}
// SetDial sets dial function in client.
func (c *Client) SetDial(dial fasthttp.DialFunc) *Client {
c.mu.Lock()
defer c.mu.Unlock()
c.fasthttp.Dial = dial
return c
}
// SetLogger sets logger instance in client.
func (c *Client) SetLogger(logger log.CommonLogger) *Client {
c.mu.Lock()
defer c.mu.Unlock()
c.logger = logger
return c
}
// Logger returns logger instance of client.
func (c *Client) Logger() log.CommonLogger {
return c.logger
}
// Reset clear Client object
func (c *Client) Reset() {
c.fasthttp = &fasthttp.Client{}
c.baseURL = ""
c.timeout = 0
c.userAgent = ""
c.referer = ""
c.proxyURL = ""
c.retryConfig = nil
c.debug = false
if c.cookieJar != nil {
c.cookieJar.Release()
c.cookieJar = nil
}
c.path.Reset()
c.cookies.Reset()
c.header.Reset()
c.params.Reset()
}
// Config for easy to set the request parameters, it should be
// noted that when setting the request body will use JSON as
// the default serialization mechanism, while the priority of
// Body is higher than FormData, and the priority of FormData
// is higher than File.
type Config struct {
Ctx context.Context //nolint:containedctx // It's needed to be stored in the config.
UserAgent string
Referer string
Header map[string]string
Param map[string]string
Cookie map[string]string
PathParam map[string]string
Timeout time.Duration
MaxRedirects int
Body any
FormData map[string]string
File []*File
}
// setConfigToRequest Set the parameters passed via Config to Request.
func setConfigToRequest(req *Request, config ...Config) {
if len(config) == 0 {
return
}
cfg := config[0]
if cfg.Ctx != nil {
req.SetContext(cfg.Ctx)
}
if cfg.UserAgent != "" {
req.SetUserAgent(cfg.UserAgent)
}
if cfg.Referer != "" {
req.SetReferer(cfg.Referer)
}
if cfg.Header != nil {
req.SetHeaders(cfg.Header)
}
if cfg.Param != nil {
req.SetParams(cfg.Param)
}
if cfg.Cookie != nil {
req.SetCookies(cfg.Cookie)
}
if cfg.PathParam != nil {
req.SetPathParams(cfg.PathParam)
}
if cfg.Timeout != 0 {
req.SetTimeout(cfg.Timeout)
}
if cfg.MaxRedirects != 0 {
req.SetMaxRedirects(cfg.MaxRedirects)
}
if cfg.Body != nil {
req.SetJSON(cfg.Body)
return
}
if cfg.FormData != nil {
req.SetFormDatas(cfg.FormData)
return
}
if cfg.File != nil && len(cfg.File) != 0 {
req.AddFiles(cfg.File...)
return
}
}
var (
defaultClient *Client
replaceMu = sync.Mutex{}
defaultUserAgent = "fiber"
)
// init acquire a default client.
func init() {
defaultClient = NewClient()
}
// NewClient creates and returns a new Client object.
func NewClient() *Client {
// FOllOW-UP performance optimization
// trie to use a pool to reduce the cost of memory allocation
// for the fiber client and the fasthttp client
// if possible also for other structs -> request header, cookie, query param, path param...
return &Client{
fasthttp: &fasthttp.Client{},
header: &Header{
RequestHeader: &fasthttp.RequestHeader{},
},
params: &QueryParam{
Args: fasthttp.AcquireArgs(),
},
cookies: &Cookie{},
path: &PathParam{},
userRequestHooks: []RequestHook{},
builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody},
userResponseHooks: []ResponseHook{},
builtinResponseHooks: []ResponseHook{parserResponseCookie, logger},
jsonMarshal: json.Marshal,
jsonUnmarshal: json.Unmarshal,
xmlMarshal: xml.Marshal,
xmlUnmarshal: xml.Unmarshal,
logger: log.DefaultLogger(),
}
}
// C get default client.
func C() *Client {
return defaultClient
}
// Replace the defaultClient, the returned function can undo.
func Replace(c *Client) func() {
replaceMu.Lock()
defer replaceMu.Unlock()
oldClient := defaultClient
defaultClient = c
return func() {
replaceMu.Lock()
defer replaceMu.Unlock()
defaultClient = oldClient
}
}
// Get send a get request use defaultClient, a convenient method.
func Get(url string, cfg ...Config) (*Response, error) {
return C().Get(url, cfg...)
}
// Post send a post request use defaultClient, a convenient method.
func Post(url string, cfg ...Config) (*Response, error) {
return C().Post(url, cfg...)
}
// Head send a head request use defaultClient, a convenient method.
func Head(url string, cfg ...Config) (*Response, error) {
return C().Head(url, cfg...)
}
// Put send a put request use defaultClient, a convenient method.
func Put(url string, cfg ...Config) (*Response, error) {
return C().Put(url, cfg...)
}
// Delete send a delete request use defaultClient, a convenient method.
func Delete(url string, cfg ...Config) (*Response, error) {
return C().Delete(url, cfg...)
}
// Options send a options request use defaultClient, a convenient method.
func Options(url string, cfg ...Config) (*Response, error) {
return C().Options(url, cfg...)
}
// Patch send a patch request use defaultClient, a convenient method.
func Patch(url string, cfg ...Config) (*Response, error) {
return C().Patch(url, cfg...)
}

1642
client/client_test.go Normal file

File diff suppressed because it is too large Load Diff

245
client/cookiejar.go Normal file
View File

@ -0,0 +1,245 @@
// The code has been taken from https://github.com/valyala/fasthttp/pull/526 originally.
package client
import (
"bytes"
"errors"
"net"
"sync"
"time"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
var cookieJarPool = sync.Pool{
New: func() any {
return &CookieJar{}
},
}
// AcquireCookieJar returns an empty CookieJar object from pool.
func AcquireCookieJar() *CookieJar {
jar, ok := cookieJarPool.Get().(*CookieJar)
if !ok {
panic(errors.New("failed to type-assert to *CookieJar"))
}
return jar
}
// ReleaseCookieJar returns CookieJar to the pool.
func ReleaseCookieJar(c *CookieJar) {
c.Release()
cookieJarPool.Put(c)
}
// CookieJar manages cookie storage. It is used by the client to store cookies.
type CookieJar struct {
mu sync.Mutex
hostCookies map[string][]*fasthttp.Cookie
}
// Get returns the cookies stored from a specific domain.
// If there were no cookies related with host returned slice will be nil.
//
// CookieJar keeps a copy of the cookies, so the returned cookies can be released safely.
func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie {
if uri == nil {
return nil
}
return cj.getByHostAndPath(uri.Host(), uri.Path())
}
// get returns the cookies stored from a specific host and path.
func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie {
if cj.hostCookies == nil {
return nil
}
var (
err error
cookies []*fasthttp.Cookie
hostStr = utils.UnsafeString(host)
)
// port must not be included.
hostStr, _, err = net.SplitHostPort(hostStr)
if err != nil {
hostStr = utils.UnsafeString(host)
}
// get cookies deleting expired ones
cookies = cj.getCookiesByHost(hostStr)
newCookies := make([]*fasthttp.Cookie, 0, len(cookies))
for i := 0; i < len(cookies); i++ {
cookie := cookies[i]
if len(path) > 1 && len(cookie.Path()) > 1 && !bytes.HasPrefix(cookie.Path(), path) {
continue
}
newCookies = append(newCookies, cookie)
}
return newCookies
}
// getCookiesByHost returns the cookies stored from a specific host.
// If cookies are expired they will be deleted.
func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie {
cj.mu.Lock()
defer cj.mu.Unlock()
now := time.Now()
cookies := cj.hostCookies[host]
for i := 0; i < len(cookies); i++ {
c := cookies[i]
if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) { // release cookie if expired
cookies = append(cookies[:i], cookies[i+1:]...)
fasthttp.ReleaseCookie(c)
i--
}
}
return cookies
}
// Set sets cookies for a specific host.
// The host is get from uri.Host().
// If the cookie key already exists it will be replaced by the new cookie value.
//
// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely.
func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) {
if uri == nil {
return
}
cj.SetByHost(uri.Host(), cookies...)
}
// SetByHost sets cookies for a specific host.
// If the cookie key already exists it will be replaced by the new cookie value.
//
// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely.
func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
hostStr := utils.UnsafeString(host)
cj.mu.Lock()
defer cj.mu.Unlock()
if cj.hostCookies == nil {
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
}
hostCookies, ok := cj.hostCookies[hostStr]
if !ok {
// If the key does not exist in the map, then we must make a copy for the key to avoid unsafe usage.
hostStr = string(host)
}
for _, cookie := range cookies {
c := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies)
if c == nil {
// If the cookie does not exist in the slice, let's acquire new cookie and store it.
c = fasthttp.AcquireCookie()
hostCookies = append(hostCookies, c)
}
c.CopyTo(cookie) // override cookie properties
}
cj.hostCookies[hostStr] = hostCookies
}
// SetKeyValue sets a cookie by key and value for a specific host.
//
// This function prevents extra allocations by making repeated cookies
// not being duplicated.
func (cj *CookieJar) SetKeyValue(host, key, value string) {
c := fasthttp.AcquireCookie()
c.SetKey(key)
c.SetValue(value)
cj.SetByHost(utils.UnsafeBytes(host), c)
}
// SetKeyValueBytes sets a cookie by key and value for a specific host.
//
// This function prevents extra allocations by making repeated cookies
// not being duplicated.
func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) {
c := fasthttp.AcquireCookie()
c.SetKeyBytes(key)
c.SetValueBytes(value)
cj.SetByHost(utils.UnsafeBytes(host), c)
}
// dumpCookiesToReq dumps the stored cookies to the request.
func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) {
uri := req.URI()
cookies := cj.getByHostAndPath(uri.Host(), uri.Path())
for _, cookie := range cookies {
req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value())
}
}
// parseCookiesFromResp parses the response cookies and stores them.
func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) {
hostStr := utils.UnsafeString(host)
cj.mu.Lock()
defer cj.mu.Unlock()
if cj.hostCookies == nil {
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
}
cookies, ok := cj.hostCookies[hostStr]
if !ok {
// If the key does not exist in the map then
// we must make a copy for the key to avoid unsafe usage.
hostStr = string(host)
}
now := time.Now()
resp.Header.VisitAllCookie(func(key, value []byte) {
isCreated := false
c := searchCookieByKeyAndPath(key, path, cookies)
if c == nil {
c, isCreated = fasthttp.AcquireCookie(), true
}
_ = c.ParseBytes(value) //nolint:errcheck // ignore error
if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) {
cookies = append(cookies, c)
} else if isCreated {
fasthttp.ReleaseCookie(c)
}
})
cj.hostCookies[hostStr] = cookies
}
// Release releases all cookie values.
func (cj *CookieJar) Release() {
// FOllOW-UP performance optimization
// currently a race condition is found because the reset method modifies a value which is not a copy but a reference -> solution should be to make a copy
// for _, v := range cj.hostCookies {
// for _, c := range v {
// fasthttp.ReleaseCookie(c)
// }
// }
cj.hostCookies = nil
}
// searchCookieByKeyAndPath searches for a cookie by key and path.
func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie {
for _, c := range cookies {
if bytes.Equal(key, c.Key()) {
if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) {
return c
}
}
}
return nil
}

213
client/cookiejar_test.go Normal file
View File

@ -0,0 +1,213 @@
package client
import (
"bytes"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func checkKeyValue(t *testing.T, cj *CookieJar, cookie *fasthttp.Cookie, uri *fasthttp.URI, n int) {
t.Helper()
cs := cj.Get(uri)
require.GreaterOrEqual(t, len(cs), n)
c := cs[n-1]
require.NotNil(t, c)
require.Equal(t, string(c.Key()), string(cookie.Key()))
require.Equal(t, string(c.Value()), string(cookie.Value()))
}
func TestCookieJarGet(t *testing.T) {
t.Parallel()
url := []byte("http://fasthttp.com/")
url1 := []byte("http://fasthttp.com/make")
url11 := []byte("http://fasthttp.com/hola")
url2 := []byte("http://fasthttp.com/make/fasthttp")
url3 := []byte("http://fasthttp.com/make/fasthttp/great")
prefix := []byte("/")
prefix1 := []byte("/make")
prefix2 := []byte("/make/fasthttp")
prefix3 := []byte("/make/fasthttp/great")
cj := &CookieJar{}
c1 := &fasthttp.Cookie{}
c1.SetKey("k")
c1.SetValue("v")
c1.SetPath("/make/")
c2 := &fasthttp.Cookie{}
c2.SetKey("kk")
c2.SetValue("vv")
c2.SetPath("/make/fasthttp")
c3 := &fasthttp.Cookie{}
c3.SetKey("kkk")
c3.SetValue("vvv")
c3.SetPath("/make/fasthttp/great")
uri := fasthttp.AcquireURI()
require.NoError(t, uri.Parse(nil, url))
uri1 := fasthttp.AcquireURI()
require.NoError(t, uri1.Parse(nil, url1))
uri11 := fasthttp.AcquireURI()
require.NoError(t, uri11.Parse(nil, url11))
uri2 := fasthttp.AcquireURI()
require.NoError(t, uri2.Parse(nil, url2))
uri3 := fasthttp.AcquireURI()
require.NoError(t, uri3.Parse(nil, url3))
cj.Set(uri1, c1, c2, c3)
cookies := cj.Get(uri1)
require.Len(t, cookies, 3)
for _, cookie := range cookies {
require.True(t, bytes.HasPrefix(cookie.Path(), prefix1))
}
cookies = cj.Get(uri11)
require.Empty(t, cookies)
cookies = cj.Get(uri2)
require.Len(t, cookies, 2)
for _, cookie := range cookies {
require.True(t, bytes.HasPrefix(cookie.Path(), prefix2))
}
cookies = cj.Get(uri3)
require.Len(t, cookies, 1)
for _, cookie := range cookies {
require.True(t, bytes.HasPrefix(cookie.Path(), prefix3))
}
cookies = cj.Get(uri)
require.Len(t, cookies, 3)
for _, cookie := range cookies {
require.True(t, bytes.HasPrefix(cookie.Path(), prefix))
}
}
func TestCookieJarGetExpired(t *testing.T) {
t.Parallel()
url1 := []byte("http://fasthttp.com/make/")
uri1 := fasthttp.AcquireURI()
require.NoError(t, uri1.Parse(nil, url1))
c1 := &fasthttp.Cookie{}
c1.SetKey("k")
c1.SetValue("v")
c1.SetExpire(time.Now().Add(-time.Hour))
cj := &CookieJar{}
cj.Set(uri1, c1)
cookies := cj.Get(uri1)
require.Empty(t, cookies)
}
func TestCookieJarSet(t *testing.T) {
t.Parallel()
url := []byte("http://fasthttp.com/hello/world")
cj := &CookieJar{}
cookie := &fasthttp.Cookie{}
cookie.SetKey("k")
cookie.SetValue("v")
uri := fasthttp.AcquireURI()
require.NoError(t, uri.Parse(nil, url))
cj.Set(uri, cookie)
checkKeyValue(t, cj, cookie, uri, 1)
}
func TestCookieJarSetRepeatedCookieKeys(t *testing.T) {
t.Parallel()
host := "fast.http"
cj := &CookieJar{}
uri := fasthttp.AcquireURI()
uri.SetHost(host)
cookie := &fasthttp.Cookie{}
cookie.SetKey("k")
cookie.SetValue("v")
cookie2 := &fasthttp.Cookie{}
cookie2.SetKey("k")
cookie2.SetValue("v2")
cookie3 := &fasthttp.Cookie{}
cookie3.SetKey("key")
cookie3.SetValue("value")
cj.Set(uri, cookie, cookie2, cookie3)
cookies := cj.Get(uri)
require.Len(t, cookies, 2)
require.Equal(t, cookies[0], cookie2)
require.True(t, bytes.Equal(cookies[0].Value(), cookie2.Value()))
}
func TestCookieJarSetKeyValue(t *testing.T) {
t.Parallel()
host := "fast.http"
cj := &CookieJar{}
uri := fasthttp.AcquireURI()
uri.SetHost(host)
cj.SetKeyValue(host, "k", "v")
cj.SetKeyValue(host, "key", "value")
cj.SetKeyValue(host, "k", "vv")
cj.SetKeyValue(host, "key", "value2")
cookies := cj.Get(uri)
require.Len(t, cookies, 2)
}
func TestCookieJarGetFromResponse(t *testing.T) {
t.Parallel()
res := fasthttp.AcquireResponse()
host := []byte("fast.http")
uri := fasthttp.AcquireURI()
uri.SetHostBytes(host)
c := &fasthttp.Cookie{}
c.SetKey("key")
c.SetValue("val")
c2 := &fasthttp.Cookie{}
c2.SetKey("k")
c2.SetValue("v")
c3 := &fasthttp.Cookie{}
c3.SetKey("kk")
c3.SetValue("vv")
res.Header.SetStatusCode(200)
res.Header.SetCookie(c)
res.Header.SetCookie(c2)
res.Header.SetCookie(c3)
cj := &CookieJar{}
cj.parseCookiesFromResp(host, nil, res)
cookies := cj.Get(uri)
require.Len(t, cookies, 3)
}

272
client/core.go Normal file
View File

@ -0,0 +1,272 @@
package client
import (
"context"
"errors"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/addon/retry"
"github.com/valyala/fasthttp"
)
var boundary = "--FiberFormBoundary"
// RequestHook is a function that receives Agent and Request,
// it can change the data in Request and Agent.
//
// Called before a request is sent.
type RequestHook func(*Client, *Request) error
// ResponseHook is a function that receives Agent, Response and Request,
// it can change the data is Response or deal with some effects.
//
// Called after a response has been received.
type ResponseHook func(*Client, *Response, *Request) error
// RetryConfig is an alias for config in the `addon/retry` package.
type RetryConfig = retry.Config
// addMissingPort will add the corresponding port number for host.
func addMissingPort(addr string, isTLS bool) string { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here
n := strings.Index(addr, ":")
if n >= 0 {
return addr
}
port := 80
if isTLS {
port = 443
}
return net.JoinHostPort(addr, strconv.Itoa(port))
}
// `core` stores middleware and plugin definitions,
// and defines the execution process
type core struct {
client *Client
req *Request
ctx context.Context //nolint:containedctx // It's needed to be stored in the core.
}
// getRetryConfig returns the retry configuration of the client.
func (c *core) getRetryConfig() *RetryConfig {
c.client.mu.RLock()
defer c.client.mu.RUnlock()
cfg := c.client.RetryConfig()
if cfg == nil {
return nil
}
return &RetryConfig{
InitialInterval: cfg.InitialInterval,
MaxBackoffTime: cfg.MaxBackoffTime,
Multiplier: cfg.Multiplier,
MaxRetryCount: cfg.MaxRetryCount,
}
}
// execFunc is the core function of the client.
// It sends the request and receives the response.
func (c *core) execFunc() (*Response, error) {
resp := AcquireResponse()
resp.setClient(c.client)
resp.setRequest(c.req)
// To avoid memory allocation reuse of data structures such as errch.
done := int32(0)
errCh, reqv := acquireErrChan(), fasthttp.AcquireRequest()
defer func() {
releaseErrChan(errCh)
}()
c.req.RawRequest.CopyTo(reqv)
cfg := c.getRetryConfig()
var err error
go func() {
respv := fasthttp.AcquireResponse()
defer func() {
fasthttp.ReleaseRequest(reqv)
fasthttp.ReleaseResponse(respv)
}()
if cfg != nil {
err = retry.NewExponentialBackoff(*cfg).Retry(func() error {
if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
return c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects)
}
return c.client.fasthttp.Do(reqv, respv)
})
} else {
if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
err = c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects)
} else {
err = c.client.fasthttp.Do(reqv, respv)
}
}
if atomic.CompareAndSwapInt32(&done, 0, 1) {
if err != nil {
errCh <- err
return
}
respv.CopyTo(resp.RawResponse)
errCh <- nil
}
}()
select {
case err := <-errCh:
if err != nil {
// When get error should release Response
ReleaseResponse(resp)
return nil, err
}
return resp, nil
case <-c.ctx.Done():
atomic.SwapInt32(&done, 1)
ReleaseResponse(resp)
return nil, ErrTimeoutOrCancel
}
}
// preHooks Exec request hook
func (c *core) preHooks() error {
c.client.mu.Lock()
defer c.client.mu.Unlock()
for _, f := range c.client.userRequestHooks {
err := f(c.client, c.req)
if err != nil {
return err
}
}
for _, f := range c.client.builtinRequestHooks {
err := f(c.client, c.req)
if err != nil {
return err
}
}
return nil
}
// afterHooks Exec response hooks
func (c *core) afterHooks(resp *Response) error {
c.client.mu.Lock()
defer c.client.mu.Unlock()
for _, f := range c.client.builtinResponseHooks {
err := f(c.client, resp, c.req)
if err != nil {
return err
}
}
for _, f := range c.client.userResponseHooks {
err := f(c.client, resp, c.req)
if err != nil {
return err
}
}
return nil
}
// timeout deals with timeout
func (c *core) timeout() context.CancelFunc {
var cancel context.CancelFunc
if c.req.timeout > 0 {
c.ctx, cancel = context.WithTimeout(c.ctx, c.req.timeout)
} else if c.client.timeout > 0 {
c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout)
}
return cancel
}
// execute will exec each hooks and plugins.
func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) {
// keep a reference, because pass param is boring
c.ctx = ctx
c.client = client
c.req = req
// The built-in hooks will be executed only
// after the user-defined hooks are executed.
err := c.preHooks()
if err != nil {
return nil, err
}
cancel := c.timeout()
if cancel != nil {
defer cancel()
}
// Do http request
resp, err := c.execFunc()
if err != nil {
return nil, err
}
// The built-in hooks will be executed only
// before the user-defined hooks are executed.
err = c.afterHooks(resp)
if err != nil {
resp.Close()
return nil, err
}
return resp, nil
}
var errChanPool = &sync.Pool{
New: func() any {
return make(chan error, 1)
},
}
// acquireErrChan returns an empty error chan from the pool.
//
// The returned error chan may be returned to the pool with releaseErrChan when no longer needed.
// This allows reducing GC load.
func acquireErrChan() chan error {
ch, ok := errChanPool.Get().(chan error)
if !ok {
panic(errors.New("failed to type-assert to chan error"))
}
return ch
}
// releaseErrChan returns the object acquired via acquireErrChan to the pool.
//
// Do not access the released core object, otherwise data races may occur.
func releaseErrChan(ch chan error) {
errChanPool.Put(ch)
}
// newCore returns an empty core object.
func newCore() *core {
c := &core{}
return c
}
var (
ErrTimeoutOrCancel = errors.New("timeout or cancel")
ErrURLFormat = errors.New("the url is a mistake")
ErrNotSupportSchema = errors.New("the protocol is not support, only http or https")
ErrFileNoName = errors.New("the file should have name")
ErrBodyType = errors.New("the body type should be []byte")
ErrNotSupportSaveMethod = errors.New("file path and io.Writer are supported")
)

248
client/core_test.go Normal file
View File

@ -0,0 +1,248 @@
package client
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp/fasthttputil"
)
func Test_AddMissing_Port(t *testing.T) {
t.Parallel()
type args struct {
addr string
isTLS bool
}
tests := []struct {
name string
args args
want string
}{
{
name: "do anything",
args: args{
addr: "example.com:1234",
},
want: "example.com:1234",
},
{
name: "add 80 port",
args: args{
addr: "example.com",
},
want: "example.com:80",
},
{
name: "add 443 port",
args: args{
addr: "example.com",
isTLS: true,
},
want: "example.com:443",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS))
})
}
}
func Test_Exec_Func(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
app := fiber.New()
app.Get("/normal", func(c fiber.Ctx) error {
return c.SendString(c.Hostname())
})
app.Get("/return-error", func(_ fiber.Ctx) error {
return errors.New("the request is error")
})
app.Get("/hang-up", func(c fiber.Ctx) error {
time.Sleep(time.Second)
return c.SendString(c.Hostname() + " hang up")
})
go func() {
require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}))
}()
time.Sleep(300 * time.Millisecond)
t.Run("normal request", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
core.ctx = context.Background()
core.client = client
core.req = req
client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
req.RawRequest.SetRequestURI("http://example.com/normal")
resp, err := core.execFunc()
require.NoError(t, err)
require.Equal(t, 200, resp.RawResponse.StatusCode())
require.Equal(t, "example.com", string(resp.RawResponse.Body()))
})
t.Run("the request return an error", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
core.ctx = context.Background()
core.client = client
core.req = req
client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
req.RawRequest.SetRequestURI("http://example.com/return-error")
resp, err := core.execFunc()
require.NoError(t, err)
require.Equal(t, 500, resp.RawResponse.StatusCode())
require.Equal(t, "the request is error", string(resp.RawResponse.Body()))
})
t.Run("the request timeout", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
core.ctx = ctx
core.client = client
core.req = req
client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
req.RawRequest.SetRequestURI("http://example.com/hang-up")
_, err := core.execFunc()
require.Equal(t, ErrTimeoutOrCancel, err)
})
}
func Test_Execute(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
app := fiber.New()
app.Get("/normal", func(c fiber.Ctx) error {
return c.SendString(c.Hostname())
})
app.Get("/return-error", func(_ fiber.Ctx) error {
return errors.New("the request is error")
})
app.Get("/hang-up", func(c fiber.Ctx) error {
time.Sleep(time.Second)
return c.SendString(c.Hostname() + " hang up")
})
go func() {
require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}))
}()
t.Run("add user request hooks", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
client.AddRequestHook(func(_ *Client, _ *Request) error {
require.Equal(t, "http://example.com", req.URL())
return nil
})
client.SetDial(func(_ string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck // not needed
})
req.SetURL("http://example.com")
resp, err := core.execute(context.Background(), client, req)
require.NoError(t, err)
require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body()))
})
t.Run("add user response hooks", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
client.AddResponseHook(func(_ *Client, _ *Response, req *Request) error {
require.Equal(t, "http://example.com", req.URL())
return nil
})
client.SetDial(func(_ string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck // not needed
})
req.SetURL("http://example.com")
resp, err := core.execute(context.Background(), client, req)
require.NoError(t, err)
require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body()))
})
t.Run("no timeout", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
client.SetDial(func(_ string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck // not needed
})
req.SetURL("http://example.com/hang-up")
resp, err := core.execute(context.Background(), client, req)
require.NoError(t, err)
require.Equal(t, "example.com hang up", string(resp.RawResponse.Body()))
})
t.Run("client timeout", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
client.SetTimeout(500 * time.Millisecond)
client.SetDial(func(_ string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck // not needed
})
req.SetURL("http://example.com/hang-up")
_, err := core.execute(context.Background(), client, req)
require.Equal(t, ErrTimeoutOrCancel, err)
})
t.Run("request timeout", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
client.SetDial(func(_ string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck // not needed
})
req.SetURL("http://example.com/hang-up").
SetTimeout(300 * time.Millisecond)
_, err := core.execute(context.Background(), client, req)
require.Equal(t, ErrTimeoutOrCancel, err)
})
t.Run("request timeout has higher level", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), NewClient(), AcquireRequest()
client.SetTimeout(30 * time.Millisecond)
client.SetDial(func(_ string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck // not needed
})
req.SetURL("http://example.com/hang-up").
SetTimeout(3000 * time.Millisecond)
resp, err := core.execute(context.Background(), client, req)
require.NoError(t, err)
require.Equal(t, "example.com hang up", string(resp.RawResponse.Body()))
})
}

157
client/helper_test.go Normal file
View File

@ -0,0 +1,157 @@
package client
import (
"net"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp/fasthttputil"
)
type testServer struct {
app *fiber.App
ch chan struct{}
ln *fasthttputil.InmemoryListener
tb testing.TB
}
func startTestServer(tb testing.TB, beforeStarting func(app *fiber.App)) *testServer {
tb.Helper()
ln := fasthttputil.NewInmemoryListener()
app := fiber.New()
if beforeStarting != nil {
beforeStarting(app)
}
ch := make(chan struct{})
go func() {
if err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}); err != nil {
tb.Fatal(err)
}
close(ch)
}()
return &testServer{
app: app,
ch: ch,
ln: ln,
tb: tb,
}
}
func (ts *testServer) stop() {
ts.tb.Helper()
if err := ts.app.Shutdown(); err != nil {
ts.tb.Fatal(err)
}
select {
case <-ts.ch:
case <-time.After(time.Second):
ts.tb.Fatalf("timeout when waiting for server close")
}
}
func (ts *testServer) dial() func(addr string) (net.Conn, error) {
ts.tb.Helper()
return func(_ string) (net.Conn, error) {
return ts.ln.Dial() //nolint:wrapcheck // not needed
}
}
func createHelperServer(tb testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) {
tb.Helper()
ln := fasthttputil.NewInmemoryListener()
app := fiber.New()
return app, func(_ string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck // not needed
}, func() {
require.NoError(tb, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}))
}
}
func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) {
t.Helper()
app, ln, start := createHelperServer(t)
app.Get("/", handler)
go start()
c := 1
if len(count) > 0 {
c = count[0]
}
client := NewClient().SetDial(ln)
for i := 0; i < c; i++ {
req := AcquireRequest().SetClient(client)
wrapAgent(req)
resp, err := req.Get("http://example.com")
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, excepted, resp.String())
resp.Close()
}
}
func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) {
t.Helper()
app, ln, start := createHelperServer(t)
app.Get("/", handler)
go start()
c := 1
if len(count) > 0 {
c = count[0]
}
client := NewClient().SetDial(ln)
for i := 0; i < c; i++ {
req := AcquireRequest().SetClient(client)
wrapAgent(req)
_, err := req.Get("http://example.com")
require.Equal(t, excepted.Error(), err.Error())
}
}
func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint: unparam // maybe needed
t.Helper()
app, ln, start := createHelperServer(t)
app.Get("/", handler)
go start()
c := 1
if len(count) > 0 {
c = count[0]
}
for i := 0; i < c; i++ {
client := NewClient().SetDial(ln)
wrapAgent(client)
resp, err := client.Get("http://example.com")
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, excepted, resp.String())
resp.Close()
}
}

328
client/hooks.go Normal file
View File

@ -0,0 +1,328 @@
package client
import (
"errors"
"fmt"
"io"
"math/rand"
"mime/multipart"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
var (
protocolCheck = regexp.MustCompile(`^https?://.*$`)
headerAccept = "Accept"
applicationJSON = "application/json"
applicationXML = "application/xml"
applicationForm = "application/x-www-form-urlencoded"
multipartFormData = "multipart/form-data"
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
)
// randString returns a random string with n length
func randString(n int) string {
b := make([]byte, n)
length := len(letterBytes)
src := rand.NewSource(time.Now().UnixNano())
for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
}
if idx := int(cache & int64(letterIdxMask)); idx < length {
b[i] = letterBytes[idx]
i--
}
cache >>= int64(letterIdxBits)
remain--
}
return utils.UnsafeString(b)
}
// parserRequestURL will set the options for the hostclient
// and normalize the url.
// The baseUrl will be merge with request uri.
// Query params and path params deal in this function.
func parserRequestURL(c *Client, req *Request) error {
splitURL := strings.Split(req.url, "?")
// I don't want to judge splitURL length.
splitURL = append(splitURL, "")
// Determine whether to superimpose baseurl based on
// whether the URL starts with the protocol
uri := splitURL[0]
if !protocolCheck.MatchString(uri) {
uri = c.baseURL + uri
if !protocolCheck.MatchString(uri) {
return ErrURLFormat
}
}
// set path params
req.path.VisitAll(func(key, val string) {
uri = strings.ReplaceAll(uri, ":"+key, val)
})
c.path.VisitAll(func(key, val string) {
uri = strings.ReplaceAll(uri, ":"+key, val)
})
// set uri to request and other related setting
req.RawRequest.SetRequestURI(uri)
// merge query params
hashSplit := strings.Split(splitURL[1], "#")
hashSplit = append(hashSplit, "")
args := fasthttp.AcquireArgs()
defer func() {
fasthttp.ReleaseArgs(args)
}()
args.Parse(hashSplit[0])
c.params.VisitAll(func(key, value []byte) {
args.AddBytesKV(key, value)
})
req.params.VisitAll(func(key, value []byte) {
args.AddBytesKV(key, value)
})
req.RawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString()))
req.RawRequest.URI().SetHash(hashSplit[1])
return nil
}
// parserRequestHeader will make request header up.
// It will merge headers from client and request.
// Header should be set automatically based on data.
// User-Agent should be set.
func parserRequestHeader(c *Client, req *Request) error {
// set method
req.RawRequest.Header.SetMethod(req.Method())
// merge header
c.header.VisitAll(func(key, value []byte) {
req.RawRequest.Header.AddBytesKV(key, value)
})
req.header.VisitAll(func(key, value []byte) {
req.RawRequest.Header.AddBytesKV(key, value)
})
// according to data set content-type
switch req.bodyType {
case jsonBody:
req.RawRequest.Header.SetContentType(applicationJSON)
req.RawRequest.Header.Set(headerAccept, applicationJSON)
case xmlBody:
req.RawRequest.Header.SetContentType(applicationXML)
case formBody:
req.RawRequest.Header.SetContentType(applicationForm)
case filesBody:
req.RawRequest.Header.SetContentType(multipartFormData)
// set boundary
if req.boundary == boundary {
req.boundary += randString(16)
}
req.RawRequest.Header.SetMultipartFormBoundary(req.boundary)
default:
}
// set useragent
req.RawRequest.Header.SetUserAgent(defaultUserAgent)
if c.userAgent != "" {
req.RawRequest.Header.SetUserAgent(c.userAgent)
}
if req.userAgent != "" {
req.RawRequest.Header.SetUserAgent(req.userAgent)
}
// set referer
req.RawRequest.Header.SetReferer(c.referer)
if req.referer != "" {
req.RawRequest.Header.SetReferer(req.referer)
}
// set cookie
// add cookie form jar to req
if c.cookieJar != nil {
c.cookieJar.dumpCookiesToReq(req.RawRequest)
}
c.cookies.VisitAll(func(key, val string) {
req.RawRequest.Header.SetCookie(key, val)
})
req.cookies.VisitAll(func(key, val string) {
req.RawRequest.Header.SetCookie(key, val)
})
return nil
}
// parserRequestBody automatically serializes the data according to
// the data type and stores it in the body of the rawRequest
func parserRequestBody(c *Client, req *Request) error {
switch req.bodyType {
case jsonBody:
body, err := c.jsonMarshal(req.body)
if err != nil {
return err
}
req.RawRequest.SetBody(body)
case xmlBody:
body, err := c.xmlMarshal(req.body)
if err != nil {
return err
}
req.RawRequest.SetBody(body)
case formBody:
req.RawRequest.SetBody(req.formData.QueryString())
case filesBody:
return parserRequestBodyFile(req)
case rawBody:
if body, ok := req.body.([]byte); ok {
req.RawRequest.SetBody(body)
} else {
return ErrBodyType
}
case noBody:
return nil
}
return nil
}
// parserRequestBodyFile parses request body if body type is file
// this is an addition of parserRequestBody.
func parserRequestBodyFile(req *Request) error {
mw := multipart.NewWriter(req.RawRequest.BodyWriter())
err := mw.SetBoundary(req.boundary)
if err != nil {
return fmt.Errorf("set boundary error: %w", err)
}
defer func() {
err := mw.Close()
if err != nil {
return
}
}()
// add formdata
req.formData.VisitAll(func(key, value []byte) {
if err != nil {
return
}
err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value))
})
if err != nil {
return fmt.Errorf("write formdata error: %w", err)
}
// add file
b := make([]byte, 512)
for i, v := range req.files {
if v.name == "" && v.path == "" {
return ErrFileNoName
}
// if name is not exist, set name
if v.name == "" && v.path != "" {
v.path = filepath.Clean(v.path)
v.name = filepath.Base(v.path)
}
// if field name is not exist, set it
if v.fieldName == "" {
v.fieldName = "file" + strconv.Itoa(i+1)
}
// check the reader
if v.reader == nil {
v.reader, err = os.Open(v.path)
if err != nil {
return fmt.Errorf("open file error: %w", err)
}
}
// write file
w, err := mw.CreateFormFile(v.fieldName, v.name)
if err != nil {
return fmt.Errorf("create file error: %w", err)
}
for {
n, err := v.reader.Read(b)
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("read file error: %w", err)
}
if errors.Is(err, io.EOF) {
break
}
_, err = w.Write(b[:n])
if err != nil {
return fmt.Errorf("write file error: %w", err)
}
}
err = v.reader.Close()
if err != nil {
return fmt.Errorf("close file error: %w", err)
}
}
return nil
}
// parserResponseHeader will parse the response header and store it in the response
func parserResponseCookie(c *Client, resp *Response, req *Request) error {
var err error
resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) {
cookie := fasthttp.AcquireCookie()
err = cookie.ParseBytes(value)
if err != nil {
return
}
cookie.SetKeyBytes(key)
resp.cookie = append(resp.cookie, cookie)
})
if err != nil {
return err
}
// store cookies to jar
if c.cookieJar != nil {
c.cookieJar.parseCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse)
}
return nil
}
// logger is a response hook that logs the request and response
func logger(c *Client, resp *Response, req *Request) error {
if !c.debug {
return nil
}
c.logger.Debugf("%s\n", req.RawRequest.String())
c.logger.Debugf("%s\n", resp.RawResponse.String())
return nil
}

652
client/hooks_test.go Normal file
View File

@ -0,0 +1,652 @@
package client
import (
"bytes"
"encoding/xml"
"fmt"
"io"
"net"
"net/url"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
func Test_Rand_String(t *testing.T) {
t.Parallel()
tests := []struct {
name string
args int
}{
{
name: "test generate",
args: 16,
},
{
name: "test generate smaller string",
args: 8,
},
{
name: "test generate larger string",
args: 32,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := randString(tt.args)
require.Len(t, got, tt.args)
})
}
}
func Test_Parser_Request_URL(t *testing.T) {
t.Parallel()
t.Run("client baseurl should be set", func(t *testing.T) {
t.Parallel()
client := NewClient().SetBaseURL("http://example.com/api")
req := AcquireRequest().SetURL("")
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, "http://example.com/api", req.RawRequest.URI().String())
})
t.Run("request url should be set", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().SetURL("http://example.com/api")
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, "http://example.com/api", req.RawRequest.URI().String())
})
t.Run("the request url will override baseurl with protocol", func(t *testing.T) {
t.Parallel()
client := NewClient().SetBaseURL("http://example.com/api")
req := AcquireRequest().SetURL("http://example.com/api/v1")
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String())
})
t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) {
t.Parallel()
client := NewClient().SetBaseURL("http://example.com/api")
req := AcquireRequest().SetURL("/v1")
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String())
})
t.Run("the url is error", func(t *testing.T) {
t.Parallel()
client := NewClient().SetBaseURL("example.com/api")
req := AcquireRequest().SetURL("/v1")
err := parserRequestURL(client, req)
require.Equal(t, ErrURLFormat, err)
})
t.Run("the path param from client", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetBaseURL("http://example.com/api/:id").
SetPathParam("id", "5")
req := AcquireRequest()
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, "http://example.com/api/5", req.RawRequest.URI().String())
})
t.Run("the path param from request", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetBaseURL("http://example.com/api/:id/:name").
SetPathParam("id", "5")
req := AcquireRequest().
SetURL("/{key}").
SetPathParams(map[string]string{
"name": "fiber",
"key": "val",
}).
DelPathParams("key")
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String())
})
t.Run("the path param from request and client", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetBaseURL("http://example.com/api/:id/:name").
SetPathParam("id", "5")
req := AcquireRequest().
SetURL("/:key").
SetPathParams(map[string]string{
"name": "fiber",
"key": "val",
"id": "12",
})
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String())
})
t.Run("query params from client should be set", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetParam("foo", "bar")
req := AcquireRequest().SetURL("http://example.com/api/v1")
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, []byte("foo=bar"), req.RawRequest.URI().QueryString())
})
t.Run("query params from request should be set", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
SetURL("http://example.com/api/v1").
SetParam("bar", "foo")
err := parserRequestURL(client, req)
require.NoError(t, err)
require.Equal(t, []byte("bar=foo"), req.RawRequest.URI().QueryString())
})
t.Run("query params should be merged", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetParam("bar", "foo1")
req := AcquireRequest().
SetURL("http://example.com/api/v1?bar=foo2").
SetParam("bar", "foo")
err := parserRequestURL(client, req)
require.NoError(t, err)
values, err := url.ParseQuery(string(req.RawRequest.URI().QueryString()))
require.NoError(t, err)
flag1, flag2, flag3 := false, false, false
for _, v := range values["bar"] {
if v == "foo1" {
flag1 = true
} else if v == "foo2" {
flag2 = true
} else if v == "foo" {
flag3 = true
}
}
require.True(t, flag1)
require.True(t, flag2)
require.True(t, flag3)
})
}
func Test_Parser_Request_Header(t *testing.T) {
t.Parallel()
t.Run("client header should be set", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetHeaders(map[string]string{
fiber.HeaderContentType: "application/json",
})
req := AcquireRequest()
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("application/json"), req.RawRequest.Header.ContentType())
})
t.Run("request header should be set", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
SetHeaders(map[string]string{
fiber.HeaderContentType: "application/json, utf-8",
})
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType())
})
t.Run("request header should override client header", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetHeader(fiber.HeaderContentType, "application/xml")
req := AcquireRequest().
SetHeader(fiber.HeaderContentType, "application/json, utf-8")
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType())
})
t.Run("auto set json header", func(t *testing.T) {
t.Parallel()
type jsonData struct {
Name string `json:"name"`
}
client := NewClient()
req := AcquireRequest().
SetJSON(jsonData{
Name: "foo",
})
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte(applicationJSON), req.RawRequest.Header.ContentType())
})
t.Run("auto set xml header", func(t *testing.T) {
t.Parallel()
type xmlData struct {
XMLName xml.Name `xml:"body"`
Name string `xml:"name"`
}
client := NewClient()
req := AcquireRequest().
SetXML(xmlData{
Name: "foo",
})
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte(applicationXML), req.RawRequest.Header.ContentType())
})
t.Run("auto set form data header", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
SetFormDatas(map[string]string{
"foo": "bar",
"ball": "cricle and square",
})
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, applicationForm, string(req.RawRequest.Header.ContentType()))
})
t.Run("auto set file header", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))).
SetFormData("foo", "bar")
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.True(t, strings.Contains(string(req.RawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary"))
require.True(t, strings.Contains(string(req.RawRequest.Header.ContentType()), multipartFormData))
})
t.Run("ua should have default value", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest()
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("fiber"), req.RawRequest.Header.UserAgent())
})
t.Run("ua in client should be set", func(t *testing.T) {
t.Parallel()
client := NewClient().SetUserAgent("foo")
req := AcquireRequest()
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("foo"), req.RawRequest.Header.UserAgent())
})
t.Run("ua in request should have higher level", func(t *testing.T) {
t.Parallel()
client := NewClient().SetUserAgent("foo")
req := AcquireRequest().SetUserAgent("bar")
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("bar"), req.RawRequest.Header.UserAgent())
})
t.Run("referer in client should be set", func(t *testing.T) {
t.Parallel()
client := NewClient().SetReferer("https://example.com")
req := AcquireRequest()
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer())
})
t.Run("referer in request should have higher level", func(t *testing.T) {
t.Parallel()
client := NewClient().SetReferer("http://example.com")
req := AcquireRequest().SetReferer("https://example.com")
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer())
})
t.Run("client cookie should be set", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetCookie("foo", "bar").
SetCookies(map[string]string{
"bar": "foo",
"bar1": "foo1",
}).
DelCookies("bar1")
req := AcquireRequest()
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
require.Equal(t, "foo", string(req.RawRequest.Header.Cookie("bar")))
require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1")))
})
t.Run("request cookie should be set", func(t *testing.T) {
t.Parallel()
type cookies struct {
Foo string `cookie:"foo"`
Bar int `cookie:"bar"`
}
client := NewClient()
req := AcquireRequest().
SetCookiesWithStruct(&cookies{
Foo: "bar",
Bar: 67,
})
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar")))
require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1")))
})
t.Run("request cookie will override client cookie", func(t *testing.T) {
t.Parallel()
type cookies struct {
Foo string `cookie:"foo"`
Bar int `cookie:"bar"`
}
client := NewClient().
SetCookie("foo", "bar").
SetCookies(map[string]string{
"bar": "foo",
"bar1": "foo1",
})
req := AcquireRequest().
SetCookiesWithStruct(&cookies{
Foo: "bar",
Bar: 67,
})
err := parserRequestHeader(client, req)
require.NoError(t, err)
require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar")))
require.Equal(t, "foo1", string(req.RawRequest.Header.Cookie("bar1")))
})
}
func Test_Parser_Request_Body(t *testing.T) {
t.Parallel()
t.Run("json body", func(t *testing.T) {
t.Parallel()
type jsonData struct {
Name string `json:"name"`
}
client := NewClient()
req := AcquireRequest().
SetJSON(jsonData{
Name: "foo",
})
err := parserRequestBody(client, req)
require.NoError(t, err)
require.Equal(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body())
})
t.Run("xml body", func(t *testing.T) {
t.Parallel()
type xmlData struct {
XMLName xml.Name `xml:"body"`
Name string `xml:"name"`
}
client := NewClient()
req := AcquireRequest().
SetXML(xmlData{
Name: "foo",
})
err := parserRequestBody(client, req)
require.NoError(t, err)
require.Equal(t, []byte("<body><name>foo</name></body>"), req.RawRequest.Body())
})
t.Run("form data body", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
SetFormDatas(map[string]string{
"ball": "cricle and square",
})
err := parserRequestBody(client, req)
require.NoError(t, err)
require.Equal(t, "ball=cricle+and+square", string(req.RawRequest.Body()))
})
t.Run("form data body error", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
SetFormDatas(map[string]string{
"": "",
})
err := parserRequestBody(client, req)
require.NoError(t, err)
})
t.Run("file body", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
AddFileWithReader("hello", io.NopCloser(strings.NewReader("world")))
err := parserRequestBody(client, req)
require.NoError(t, err)
require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary"))
require.True(t, strings.Contains(string(req.RawRequest.Body()), "world"))
})
t.Run("file and form data", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))).
SetFormData("foo", "bar")
err := parserRequestBody(client, req)
require.NoError(t, err)
require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary"))
require.True(t, strings.Contains(string(req.RawRequest.Body()), "world"))
require.True(t, strings.Contains(string(req.RawRequest.Body()), "bar"))
})
t.Run("raw body", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
SetRawBody([]byte("hello world"))
err := parserRequestBody(client, req)
require.NoError(t, err)
require.Equal(t, []byte("hello world"), req.RawRequest.Body())
})
t.Run("raw body error", func(t *testing.T) {
t.Parallel()
client := NewClient()
req := AcquireRequest().
SetRawBody([]byte("hello world"))
req.body = nil
err := parserRequestBody(client, req)
require.ErrorIs(t, err, ErrBodyType)
})
}
type dummyLogger struct {
buf *bytes.Buffer
}
func (*dummyLogger) Trace(_ ...any) {}
func (*dummyLogger) Debug(_ ...any) {}
func (*dummyLogger) Info(_ ...any) {}
func (*dummyLogger) Warn(_ ...any) {}
func (*dummyLogger) Error(_ ...any) {}
func (*dummyLogger) Fatal(_ ...any) {}
func (*dummyLogger) Panic(_ ...any) {}
func (*dummyLogger) Tracef(_ string, _ ...any) {}
func (l *dummyLogger) Debugf(format string, v ...any) {
_, _ = l.buf.WriteString(fmt.Sprintf(format, v...)) //nolint:errcheck // not needed
}
func (*dummyLogger) Infof(_ string, _ ...any) {}
func (*dummyLogger) Warnf(_ string, _ ...any) {}
func (*dummyLogger) Errorf(_ string, _ ...any) {}
func (*dummyLogger) Fatalf(_ string, _ ...any) {}
func (*dummyLogger) Panicf(_ string, _ ...any) {}
func (*dummyLogger) Tracew(_ string, _ ...any) {}
func (*dummyLogger) Debugw(_ string, _ ...any) {}
func (*dummyLogger) Infow(_ string, _ ...any) {}
func (*dummyLogger) Warnw(_ string, _ ...any) {}
func (*dummyLogger) Errorw(_ string, _ ...any) {}
func (*dummyLogger) Fatalw(_ string, _ ...any) {}
func (*dummyLogger) Panicw(_ string, _ ...any) {}
func Test_Client_Logger_Debug(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("response")
})
addrChan := make(chan string)
go func() {
require.NoError(t, app.Listen(":0", fiber.ListenConfig{
DisableStartupMessage: true,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}))
}()
defer func(app *fiber.App) {
require.NoError(t, app.Shutdown())
}(app)
var buf bytes.Buffer
logger := &dummyLogger{buf: &buf}
client := NewClient()
client.Debug().SetLogger(logger)
addr := <-addrChan
resp, err := client.Get("http://" + addr)
require.NoError(t, err)
defer resp.Close()
require.NoError(t, err)
require.Contains(t, buf.String(), "Host: "+addr)
require.Contains(t, buf.String(), "Content-Length: 8")
}
func Test_Client_Logger_DisableDebug(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("response")
})
addrChan := make(chan string)
go func() {
require.NoError(t, app.Listen(":0", fiber.ListenConfig{
DisableStartupMessage: true,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}))
}()
defer func(app *fiber.App) {
require.NoError(t, app.Shutdown())
}(app)
var buf bytes.Buffer
logger := &dummyLogger{buf: &buf}
client := NewClient()
client.DisableDebug().SetLogger(logger)
addr := <-addrChan
resp, err := client.Get("http://" + addr)
require.NoError(t, err)
defer resp.Close()
require.NoError(t, err)
require.Empty(t, buf.String())
}

985
client/request.go Normal file
View File

@ -0,0 +1,985 @@
package client
import (
"bytes"
"context"
"errors"
"io"
"path/filepath"
"reflect"
"strconv"
"sync"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
// WithStruct Implementing this interface allows data to
// be stored from a struct via reflect.
type WithStruct interface {
Add(name, obj string)
Del(name string)
}
// Types of request bodies.
type bodyType int
// Enumeration definition of the request body type.
const (
noBody bodyType = iota
jsonBody
xmlBody
formBody
filesBody
rawBody
)
var ErrClientNil = errors.New("client can not be nil")
// Request is a struct which contains the request data.
type Request struct {
url string
method string
userAgent string
boundary string
referer string
ctx context.Context //nolint:containedctx // It's needed to be stored in the request.
header *Header
params *QueryParam
cookies *Cookie
path *PathParam
timeout time.Duration
maxRedirects int
client *Client
body any
formData *FormData
files []*File
bodyType bodyType
RawRequest *fasthttp.Request
}
// Method returns http method in request.
func (r *Request) Method() string {
return r.method
}
// SetMethod will set method for Request object,
// user should use request method to set method.
func (r *Request) SetMethod(method string) *Request {
r.method = method
return r
}
// URL returns request url in Request instance.
func (r *Request) URL() string {
return r.url
}
// SetURL will set url for Request object.
func (r *Request) SetURL(url string) *Request {
r.url = url
return r
}
// Client get Client instance in Request.
func (r *Request) Client() *Client {
return r.client
}
// SetClient method sets client in request instance.
func (r *Request) SetClient(c *Client) *Request {
if c == nil {
panic(ErrClientNil)
}
r.client = c
return r
}
// Context returns the Context if its already set in request
// otherwise it creates new one using `context.Background()`.
func (r *Request) Context() context.Context {
if r.ctx == nil {
return context.Background()
}
return r.ctx
}
// SetContext sets the context.Context for current Request. It allows
// to interrupt the request execution if ctx.Done() channel is closed.
// See https://blog.golang.org/context article and the "context" package
// documentation.
func (r *Request) SetContext(ctx context.Context) *Request {
r.ctx = ctx
return r
}
// Header method returns header value via key,
// this method will visit all field in the header,
// then sort them.
func (r *Request) Header(key string) []string {
return r.header.PeekMultiple(key)
}
// AddHeader method adds a single header field and its value in the request instance.
// It will override header which set in client instance.
func (r *Request) AddHeader(key, val string) *Request {
r.header.Add(key, val)
return r
}
// SetHeader method sets a single header field and its value in the request instance.
// It will override header which set in client instance.
func (r *Request) SetHeader(key, val string) *Request {
r.header.Del(key)
r.header.Set(key, val)
return r
}
// AddHeaders method adds multiple header fields and its values at one go in the request instance.
// It will override header which set in client instance.
func (r *Request) AddHeaders(h map[string][]string) *Request {
r.header.AddHeaders(h)
return r
}
// SetHeaders method sets multiple header fields and its values at one go in the request instance.
// It will override header which set in client instance.
func (r *Request) SetHeaders(h map[string]string) *Request {
r.header.SetHeaders(h)
return r
}
// Param method returns params value via key,
// this method will visit all field in the query param.
func (r *Request) Param(key string) []string {
var res []string
tmp := r.params.PeekMulti(key)
for _, v := range tmp {
res = append(res, utils.UnsafeString(v))
}
return res
}
// AddParam method adds a single param field and its value in the request instance.
// It will override param which set in client instance.
func (r *Request) AddParam(key, val string) *Request {
r.params.Add(key, val)
return r
}
// SetParam method sets a single param field and its value in the request instance.
// It will override param which set in client instance.
func (r *Request) SetParam(key, val string) *Request {
r.params.Set(key, val)
return r
}
// AddParams method adds multiple param fields and its values at one go in the request instance.
// It will override param which set in client instance.
func (r *Request) AddParams(m map[string][]string) *Request {
r.params.AddParams(m)
return r
}
// SetParams method sets multiple param fields and its values at one go in the request instance.
// It will override param which set in client instance.
func (r *Request) SetParams(m map[string]string) *Request {
r.params.SetParams(m)
return r
}
// SetParamsWithStruct method sets multiple param fields and its values at one go in the request instance.
// It will override param which set in client instance.
func (r *Request) SetParamsWithStruct(v any) *Request {
r.params.SetParamsWithStruct(v)
return r
}
// DelParams method deletes single or multiple param fields ant its values.
func (r *Request) DelParams(key ...string) *Request {
for _, v := range key {
r.params.Del(v)
}
return r
}
// UserAgent returns user agent in request instance.
func (r *Request) UserAgent() string {
return r.userAgent
}
// SetUserAgent method sets user agent in request.
// It will override user agent which set in client instance.
func (r *Request) SetUserAgent(ua string) *Request {
r.userAgent = ua
return r
}
// Boundary returns boundary in multipart boundary.
func (r *Request) Boundary() string {
return r.boundary
}
// SetBoundary method sets multipart boundary.
func (r *Request) SetBoundary(b string) *Request {
r.boundary = b
return r
}
// Referer returns referer in request instance.
func (r *Request) Referer() string {
return r.referer
}
// SetReferer method sets referer in request.
// It will override referer which set in client instance.
func (r *Request) SetReferer(referer string) *Request {
r.referer = referer
return r
}
// Cookie returns the cookie be set in request instance.
// if cookie doesn't exist, return empty string.
func (r *Request) Cookie(key string) string {
if val, ok := (*r.cookies)[key]; ok {
return val
}
return ""
}
// SetCookie method sets a single cookie field and its value in the request instance.
// It will override cookie which set in client instance.
func (r *Request) SetCookie(key, val string) *Request {
r.cookies.SetCookie(key, val)
return r
}
// SetCookies method sets multiple cookie fields and its values at one go in the request instance.
// It will override cookie which set in client instance.
func (r *Request) SetCookies(m map[string]string) *Request {
r.cookies.SetCookies(m)
return r
}
// SetCookiesWithStruct method sets multiple cookie fields and its values at one go in the request instance.
// It will override cookie which set in client instance.
func (r *Request) SetCookiesWithStruct(v any) *Request {
r.cookies.SetCookiesWithStruct(v)
return r
}
// DelCookies method deletes single or multiple cookie fields ant its values.
func (r *Request) DelCookies(key ...string) *Request {
r.cookies.DelCookies(key...)
return r
}
// PathParam returns the path param be set in request instance.
// if path param doesn't exist, return empty string.
func (r *Request) PathParam(key string) string {
if val, ok := (*r.path)[key]; ok {
return val
}
return ""
}
// SetPathParam method sets a single path param field and its value in the request instance.
// It will override path param which set in client instance.
func (r *Request) SetPathParam(key, val string) *Request {
r.path.SetParam(key, val)
return r
}
// SetPathParams method sets multiple path param fields and its values at one go in the request instance.
// It will override path param which set in client instance.
func (r *Request) SetPathParams(m map[string]string) *Request {
r.path.SetParams(m)
return r
}
// SetPathParamsWithStruct method sets multiple path param fields and its values at one go in the request instance.
// It will override path param which set in client instance.
func (r *Request) SetPathParamsWithStruct(v any) *Request {
r.path.SetParamsWithStruct(v)
return r
}
// DelPathParams method deletes single or multiple path param fields ant its values.
func (r *Request) DelPathParams(key ...string) *Request {
r.path.DelParams(key...)
return r
}
// ResetPathParams deletes all path params.
func (r *Request) ResetPathParams() *Request {
r.path.Reset()
return r
}
// SetJSON method sets json body in request.
func (r *Request) SetJSON(v any) *Request {
r.body = v
r.bodyType = jsonBody
return r
}
// SetXML method sets xml body in request.
func (r *Request) SetXML(v any) *Request {
r.body = v
r.bodyType = xmlBody
return r
}
// SetRawBody method sets body with raw data in request.
func (r *Request) SetRawBody(v []byte) *Request {
r.body = v
r.bodyType = rawBody
return r
}
// resetBody will clear body object and set bodyType
// if body type is formBody and filesBody, the new body type will be ignored.
func (r *Request) resetBody(t bodyType) {
r.body = nil
// Set form data after set file ignore.
if r.bodyType == filesBody && t == formBody {
return
}
r.bodyType = t
}
// FormData method returns form data value via key,
// this method will visit all field in the form data.
func (r *Request) FormData(key string) []string {
var res []string
tmp := r.formData.PeekMulti(key)
for _, v := range tmp {
res = append(res, utils.UnsafeString(v))
}
return res
}
// AddFormData method adds a single form data field and its value in the request instance.
func (r *Request) AddFormData(key, val string) *Request {
r.formData.AddData(key, val)
r.resetBody(formBody)
return r
}
// SetFormData method sets a single form data field and its value in the request instance.
func (r *Request) SetFormData(key, val string) *Request {
r.formData.SetData(key, val)
r.resetBody(formBody)
return r
}
// AddFormDatas method adds multiple form data fields and its values in the request instance.
func (r *Request) AddFormDatas(m map[string][]string) *Request {
r.formData.AddDatas(m)
r.resetBody(formBody)
return r
}
// SetFormDatas method sets multiple form data fields and its values in the request instance.
func (r *Request) SetFormDatas(m map[string]string) *Request {
r.formData.SetDatas(m)
r.resetBody(formBody)
return r
}
// SetFormDatasWithStruct method sets multiple form data fields
// and its values in the request instance via struct.
func (r *Request) SetFormDatasWithStruct(v any) *Request {
r.formData.SetDatasWithStruct(v)
r.resetBody(formBody)
return r
}
// DelFormDatas method deletes multiple form data fields and its value in the request instance.
func (r *Request) DelFormDatas(key ...string) *Request {
r.formData.DelDatas(key...)
r.resetBody(formBody)
return r
}
// File returns file ptr store in request obj by name.
// If name field is empty, it will try to match path.
func (r *Request) File(name string) *File {
for _, v := range r.files {
if v.name == "" {
if filepath.Base(v.path) == name {
return v
}
} else if v.name == name {
return v
}
}
return nil
}
// FileByPath returns file ptr store in request obj by path.
func (r *Request) FileByPath(path string) *File {
for _, v := range r.files {
if v.path == path {
return v
}
}
return nil
}
// AddFile method adds single file field
// and its value in the request instance via file path.
func (r *Request) AddFile(path string) *Request {
r.files = append(r.files, AcquireFile(SetFilePath(path)))
r.resetBody(filesBody)
return r
}
// AddFileWithReader method adds single field
// and its value in the request instance via reader.
func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request {
r.files = append(r.files, AcquireFile(SetFileName(name), SetFileReader(reader)))
r.resetBody(filesBody)
return r
}
// AddFiles method adds multiple file fields
// and its value in the request instance via File instance.
func (r *Request) AddFiles(files ...*File) *Request {
r.files = append(r.files, files...)
r.resetBody(filesBody)
return r
}
// Timeout returns the length of timeout in request.
func (r *Request) Timeout() time.Duration {
return r.timeout
}
// SetTimeout method sets timeout field and its values at one go in the request instance.
// It will override timeout which set in client instance.
func (r *Request) SetTimeout(t time.Duration) *Request {
r.timeout = t
return r
}
// MaxRedirects returns the max redirects count in request.
func (r *Request) MaxRedirects() int {
return r.maxRedirects
}
// SetMaxRedirects method sets the maximum number of redirects at one go in the request instance.
// It will override max redirect which set in client instance.
func (r *Request) SetMaxRedirects(count int) *Request {
r.maxRedirects = count
return r
}
// checkClient method checks whether the client has been set in request.
func (r *Request) checkClient() {
if r.client == nil {
r.SetClient(defaultClient)
}
}
// Get Send get request.
func (r *Request) Get(url string) (*Response, error) {
return r.SetURL(url).SetMethod(fiber.MethodGet).Send()
}
// Post Send post request.
func (r *Request) Post(url string) (*Response, error) {
return r.SetURL(url).SetMethod(fiber.MethodPost).Send()
}
// Head Send head request.
func (r *Request) Head(url string) (*Response, error) {
return r.SetURL(url).SetMethod(fiber.MethodHead).Send()
}
// Put Send put request.
func (r *Request) Put(url string) (*Response, error) {
return r.SetURL(url).SetMethod(fiber.MethodPut).Send()
}
// Delete Send Delete request.
func (r *Request) Delete(url string) (*Response, error) {
return r.SetURL(url).SetMethod(fiber.MethodDelete).Send()
}
// Options Send Options request.
func (r *Request) Options(url string) (*Response, error) {
return r.SetURL(url).SetMethod(fiber.MethodOptions).Send()
}
// Patch Send patch request.
func (r *Request) Patch(url string) (*Response, error) {
return r.SetURL(url).SetMethod(fiber.MethodPatch).Send()
}
// Custom Send custom request.
func (r *Request) Custom(url, method string) (*Response, error) {
return r.SetURL(url).SetMethod(method).Send()
}
// Send a request.
func (r *Request) Send() (*Response, error) {
r.checkClient()
return newCore().execute(r.Context(), r.Client(), r)
}
// Reset clear Request object, used by ReleaseRequest method.
func (r *Request) Reset() {
r.url = ""
r.method = fiber.MethodGet
r.userAgent = ""
r.referer = ""
r.ctx = nil
r.body = nil
r.timeout = 0
r.maxRedirects = 0
r.bodyType = noBody
r.boundary = boundary
for len(r.files) != 0 {
t := r.files[0]
r.files = r.files[1:]
ReleaseFile(t)
}
r.formData.Reset()
r.path.Reset()
r.cookies.Reset()
r.header.Reset()
r.params.Reset()
r.RawRequest.Reset()
}
// Header is a wrapper which wrap http.Header,
// the header in client and request will store in it.
type Header struct {
*fasthttp.RequestHeader
}
// PeekMultiple methods returns multiple field in header with same key.
func (h *Header) PeekMultiple(key string) []string {
var res []string
byteKey := []byte(key)
h.RequestHeader.VisitAll(func(key, value []byte) {
if bytes.EqualFold(key, byteKey) {
res = append(res, utils.UnsafeString(value))
}
})
return res
}
// AddHeaders receive a map and add each value to header.
func (h *Header) AddHeaders(r map[string][]string) {
for k, v := range r {
for _, vv := range v {
h.Add(k, vv)
}
}
}
// SetHeaders will override all headers.
func (h *Header) SetHeaders(r map[string]string) {
for k, v := range r {
h.Del(k)
h.Set(k, v)
}
}
// QueryParam is a wrapper which wrap url.Values,
// the query string and formdata in client and request will store in it.
type QueryParam struct {
*fasthttp.Args
}
// AddParams receive a map and add each value to param.
func (p *QueryParam) AddParams(r map[string][]string) {
for k, v := range r {
for _, vv := range v {
p.Add(k, vv)
}
}
}
// SetParams will override all params.
func (p *QueryParam) SetParams(r map[string]string) {
for k, v := range r {
p.Set(k, v)
}
}
// SetParamsWithStruct will override all params with struct or pointer of struct.
// Now nested structs are not currently supported.
func (p *QueryParam) SetParamsWithStruct(v any) {
SetValWithStruct(p, "param", v)
}
// Cookie is a map which to store the cookies.
type Cookie map[string]string
// Add method impl the method in WithStruct interface.
func (c Cookie) Add(key, val string) {
c[key] = val
}
// Del method impl the method in WithStruct interface.
func (c Cookie) Del(key string) {
delete(c, key)
}
// SetCookie method sets a single val in Cookie.
func (c Cookie) SetCookie(key, val string) {
c[key] = val
}
// SetCookies method sets multiple val in Cookie.
func (c Cookie) SetCookies(m map[string]string) {
for k, v := range m {
c[k] = v
}
}
// SetCookiesWithStruct method sets multiple val in Cookie via a struct.
func (c Cookie) SetCookiesWithStruct(v any) {
SetValWithStruct(c, "cookie", v)
}
// DelCookies method deletes multiple val in Cookie.
func (c Cookie) DelCookies(key ...string) {
for _, v := range key {
c.Del(v)
}
}
// VisitAll method receive a function which can travel the all val.
func (c Cookie) VisitAll(f func(key, val string)) {
for k, v := range c {
f(k, v)
}
}
// Reset clear the Cookie object.
func (c Cookie) Reset() {
for k := range c {
delete(c, k)
}
}
// PathParam is a map which to store the cookies.
type PathParam map[string]string
// Add method impl the method in WithStruct interface.
func (p PathParam) Add(key, val string) {
p[key] = val
}
// Del method impl the method in WithStruct interface.
func (p PathParam) Del(key string) {
delete(p, key)
}
// SetParam method sets a single val in PathParam.
func (p PathParam) SetParam(key, val string) {
p[key] = val
}
// SetParams method sets multiple val in PathParam.
func (p PathParam) SetParams(m map[string]string) {
for k, v := range m {
p[k] = v
}
}
// SetParamsWithStruct method sets multiple val in PathParam via a struct.
func (p PathParam) SetParamsWithStruct(v any) {
SetValWithStruct(p, "path", v)
}
// DelParams method deletes multiple val in PathParams.
func (p PathParam) DelParams(key ...string) {
for _, v := range key {
p.Del(v)
}
}
// VisitAll method receive a function which can travel the all val.
func (p PathParam) VisitAll(f func(key, val string)) {
for k, v := range p {
f(k, v)
}
}
// Reset clear the PathParams object.
func (p PathParam) Reset() {
for k := range p {
delete(p, k)
}
}
// FormData is a wrapper of fasthttp.Args,
// and it be used for url encode body and file body.
type FormData struct {
*fasthttp.Args
}
// AddData method is a wrapper of Args's Add method.
func (f *FormData) AddData(key, val string) {
f.Add(key, val)
}
// SetData method is a wrapper of Args's Set method.
func (f *FormData) SetData(key, val string) {
f.Set(key, val)
}
// AddDatas method supports add multiple fields.
func (f *FormData) AddDatas(m map[string][]string) {
for k, v := range m {
for _, vv := range v {
f.Add(k, vv)
}
}
}
// SetDatas method supports set multiple fields.
func (f *FormData) SetDatas(m map[string]string) {
for k, v := range m {
f.Set(k, v)
}
}
// SetDatasWithStruct method supports set multiple fields via a struct.
func (f *FormData) SetDatasWithStruct(v any) {
SetValWithStruct(f, "form", v)
}
// DelDatas method deletes multiple fields.
func (f *FormData) DelDatas(key ...string) {
for _, v := range key {
f.Del(v)
}
}
// Reset clear the FormData object.
func (f *FormData) Reset() {
f.Args.Reset()
}
// File is a struct which support send files via request.
type File struct {
name string
fieldName string
path string
reader io.ReadCloser
}
// SetName method sets file name.
func (f *File) SetName(n string) {
f.name = n
}
// SetFieldName method sets key of file in the body.
func (f *File) SetFieldName(n string) {
f.fieldName = n
}
// SetPath method set file path.
func (f *File) SetPath(p string) {
f.path = p
}
// SetReader method can receive a io.ReadCloser
// which will be closed in parserBody hook.
func (f *File) SetReader(r io.ReadCloser) {
f.reader = r
}
// Reset clear the File object.
func (f *File) Reset() {
f.name = ""
f.fieldName = ""
f.path = ""
f.reader = nil
}
var requestPool = &sync.Pool{
New: func() any {
return &Request{
header: &Header{RequestHeader: &fasthttp.RequestHeader{}},
params: &QueryParam{Args: fasthttp.AcquireArgs()},
cookies: &Cookie{},
path: &PathParam{},
boundary: "--FiberFormBoundary",
formData: &FormData{Args: fasthttp.AcquireArgs()},
files: make([]*File, 0),
RawRequest: fasthttp.AcquireRequest(),
}
},
}
// AcquireRequest returns an empty request object from the pool.
//
// The returned request may be returned to the pool with ReleaseRequest when no longer needed.
// This allows reducing GC load.
func AcquireRequest() *Request {
req, ok := requestPool.Get().(*Request)
if !ok {
panic(errors.New("failed to type-assert to *Request"))
}
return req
}
// ReleaseRequest returns the object acquired via AcquireRequest to the pool.
//
// Do not access the released Request object, otherwise data races may occur.
func ReleaseRequest(req *Request) {
req.Reset()
requestPool.Put(req)
}
var filePool sync.Pool
// SetFileFunc The methods as follows is used by AcquireFile method.
// You can set file field via these method.
type SetFileFunc func(f *File)
// SetFileName method sets file name.
func SetFileName(n string) SetFileFunc {
return func(f *File) {
f.SetName(n)
}
}
// SetFileFieldName method sets key of file in the body.
func SetFileFieldName(p string) SetFileFunc {
return func(f *File) {
f.SetFieldName(p)
}
}
// SetFilePath method set file path.
func SetFilePath(p string) SetFileFunc {
return func(f *File) {
f.SetPath(p)
}
}
// SetFileReader method can receive a io.ReadCloser
func SetFileReader(r io.ReadCloser) SetFileFunc {
return func(f *File) {
f.SetReader(r)
}
}
// AcquireFile returns an File object from the pool.
// And you can set field in the File with SetFileFunc.
//
// The returned file may be returned to the pool with ReleaseFile when no longer needed.
// This allows reducing GC load.
func AcquireFile(setter ...SetFileFunc) *File {
fv := filePool.Get()
if fv != nil {
f, ok := fv.(*File)
if !ok {
panic(errors.New("failed to type-assert to *File"))
}
for _, v := range setter {
v(f)
}
return f
}
f := &File{}
for _, v := range setter {
v(f)
}
return f
}
// ReleaseFile returns the object acquired via AcquireFile to the pool.
//
// Do not access the released File object, otherwise data races may occur.
func ReleaseFile(f *File) {
f.Reset()
filePool.Put(f)
}
// SetValWithStruct Set some values using structs.
// `p` is a structure that implements the WithStruct interface,
// The field name can be specified by `tagName`.
// `v` is a struct include some data.
// Note: This method only supports simple types and nested structs are not currently supported.
func SetValWithStruct(p WithStruct, tagName string, v any) {
valueOfV := reflect.ValueOf(v)
typeOfV := reflect.TypeOf(v)
// The v should be struct or point of struct
if typeOfV.Kind() == reflect.Pointer && typeOfV.Elem().Kind() == reflect.Struct {
valueOfV = valueOfV.Elem()
typeOfV = typeOfV.Elem()
} else if typeOfV.Kind() != reflect.Struct {
return
}
// Boring type judge.
// TODO: cover more types and complex data structure.
var setVal func(name string, value reflect.Value)
setVal = func(name string, val reflect.Value) {
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
p.Add(name, strconv.Itoa(int(val.Int())))
case reflect.Bool:
if val.Bool() {
p.Add(name, "true")
}
case reflect.String:
p.Add(name, val.String())
case reflect.Float32, reflect.Float64:
p.Add(name, strconv.FormatFloat(val.Float(), 'f', -1, 64))
case reflect.Slice, reflect.Array:
for i := 0; i < val.Len(); i++ {
setVal(name, val.Index(i))
}
default:
}
}
for i := 0; i < typeOfV.NumField(); i++ {
field := typeOfV.Field(i)
if !field.IsExported() {
continue
}
name := field.Tag.Get(tagName)
if name == "" {
name = field.Name
}
val := valueOfV.Field(i)
if val.IsZero() {
continue
}
// To cover slice and array, we delete the val then add it.
p.Del(name)
setVal(name, val)
}
}

1623
client/request_test.go Normal file

File diff suppressed because it is too large Load Diff

184
client/response.go Normal file
View File

@ -0,0 +1,184 @@
package client
import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
// Response is the result of a request. This object is used to access the response data.
type Response struct {
client *Client
request *Request
cookie []*fasthttp.Cookie
RawResponse *fasthttp.Response
}
// setClient method sets client object in response instance.
// Use core object in the client.
func (r *Response) setClient(c *Client) {
r.client = c
}
// setRequest method sets Request object in response instance.
// The request will be released when the Response.Close is called.
func (r *Response) setRequest(req *Request) {
r.request = req
}
// Status method returns the HTTP status string for the executed request.
func (r *Response) Status() string {
return string(r.RawResponse.Header.StatusMessage())
}
// StatusCode method returns the HTTP status code for the executed request.
func (r *Response) StatusCode() int {
return r.RawResponse.StatusCode()
}
// Protocol method returns the HTTP response protocol used for the request.
func (r *Response) Protocol() string {
return string(r.RawResponse.Header.Protocol())
}
// Header method returns the response headers.
func (r *Response) Header(key string) string {
return utils.UnsafeString(r.RawResponse.Header.Peek(key))
}
// Cookies method to access all the response cookies.
func (r *Response) Cookies() []*fasthttp.Cookie {
return r.cookie
}
// Body method returns HTTP response as []byte array for the executed request.
func (r *Response) Body() []byte {
return r.RawResponse.Body()
}
// String method returns the body of the server response as String.
func (r *Response) String() string {
return strings.TrimSpace(string(r.Body()))
}
// JSON method will unmarshal body to json.
func (r *Response) JSON(v any) error {
return r.client.jsonUnmarshal(r.Body(), v)
}
// XML method will unmarshal body to xml.
func (r *Response) XML(v any) error {
return r.client.xmlUnmarshal(r.Body(), v)
}
// Save method will save the body to a file or io.Writer.
func (r *Response) Save(v any) error {
switch p := v.(type) {
case string:
file := filepath.Clean(p)
dir := filepath.Dir(file)
// create directory
if _, err := os.Stat(dir); err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to check directory: %w", err)
}
if err = os.MkdirAll(dir, 0o750); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
}
// create file
outFile, err := os.Create(file)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer func() { _ = outFile.Close() }() //nolint:errcheck // not needed
_, err = io.Copy(outFile, bytes.NewReader(r.Body()))
if err != nil {
return fmt.Errorf("failed to write response body to file: %w", err)
}
return nil
case io.Writer:
_, err := io.Copy(p, bytes.NewReader(r.Body()))
if err != nil {
return fmt.Errorf("failed to write response body to io.Writer: %w", err)
}
defer func() {
if pc, ok := p.(io.WriteCloser); ok {
_ = pc.Close() //nolint:errcheck // not needed
}
}()
return nil
default:
return ErrNotSupportSaveMethod
}
}
// Reset clear Response object.
func (r *Response) Reset() {
r.client = nil
r.request = nil
for len(r.cookie) != 0 {
t := r.cookie[0]
r.cookie = r.cookie[1:]
fasthttp.ReleaseCookie(t)
}
r.RawResponse.Reset()
}
// Close method will release Request object and Response object,
// after call Close please don't use these object.
func (r *Response) Close() {
if r.request != nil {
tmp := r.request
r.request = nil
ReleaseRequest(tmp)
}
ReleaseResponse(r)
}
var responsePool = &sync.Pool{
New: func() any {
return &Response{
cookie: []*fasthttp.Cookie{},
RawResponse: fasthttp.AcquireResponse(),
}
},
}
// AcquireResponse returns an empty response object from the pool.
//
// The returned response may be returned to the pool with ReleaseResponse when no longer needed.
// This allows reducing GC load.
func AcquireResponse() *Response {
resp, ok := responsePool.Get().(*Response)
if !ok {
panic("unexpected type from responsePool.Get()")
}
return resp
}
// ReleaseResponse returns the object acquired via AcquireResponse to the pool.
//
// Do not access the released Response object, otherwise data races may occur.
func ReleaseResponse(resp *Response) {
resp.Reset()
responsePool.Put(resp)
}

418
client/response_test.go Normal file
View File

@ -0,0 +1,418 @@
package client
import (
"bytes"
"crypto/tls"
"encoding/xml"
"io"
"net"
"os"
"testing"
"github.com/gofiber/fiber/v3/internal/tlstest"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
func Test_Response_Status(t *testing.T) {
t.Parallel()
setupApp := func() *testServer {
server := startTestServer(t, func(app *fiber.App) {
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("foo")
})
app.Get("/fail", func(c fiber.Ctx) error {
return c.SendStatus(407)
})
})
return server
}
t.Run("success", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example")
require.NoError(t, err)
require.Equal(t, "OK", resp.Status())
resp.Close()
})
t.Run("fail", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example/fail")
require.NoError(t, err)
require.Equal(t, "Proxy Authentication Required", resp.Status())
resp.Close()
})
}
func Test_Response_Status_Code(t *testing.T) {
t.Parallel()
setupApp := func() *testServer {
server := startTestServer(t, func(app *fiber.App) {
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("foo")
})
app.Get("/fail", func(c fiber.Ctx) error {
return c.SendStatus(407)
})
})
return server
}
t.Run("success", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example")
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode())
resp.Close()
})
t.Run("fail", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example/fail")
require.NoError(t, err)
require.Equal(t, 407, resp.StatusCode())
resp.Close()
})
}
func Test_Response_Protocol(t *testing.T) {
t.Parallel()
t.Run("http", func(t *testing.T) {
t.Parallel()
server := startTestServer(t, func(app *fiber.App) {
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("foo")
})
})
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example")
require.NoError(t, err)
require.Equal(t, "HTTP/1.1", resp.Protocol())
resp.Close()
})
t.Run("https", func(t *testing.T) {
t.Parallel()
serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs()
require.NoError(t, err)
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
require.NoError(t, err)
ln = tls.NewListener(ln, serverTLSConf)
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(c.Scheme())
})
go func() {
require.NoError(t, app.Listener(ln, fiber.ListenConfig{
DisableStartupMessage: true,
}))
}()
client := NewClient()
resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String())
require.NoError(t, err)
require.Equal(t, clientTLSConf, client.TLSConfig())
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, "https", resp.String())
require.Equal(t, "HTTP/1.1", resp.Protocol())
resp.Close()
})
}
func Test_Response_Header(t *testing.T) {
t.Parallel()
server := startTestServer(t, func(app *fiber.App) {
app.Get("/", func(c fiber.Ctx) error {
c.Response().Header.Add("foo", "bar")
return c.SendString("helo world")
})
})
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com")
require.NoError(t, err)
require.Equal(t, "bar", resp.Header("foo"))
resp.Close()
}
func Test_Response_Cookie(t *testing.T) {
t.Parallel()
server := startTestServer(t, func(app *fiber.App) {
app.Get("/", func(c fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "foo",
Value: "bar",
})
return c.SendString("helo world")
})
})
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com")
require.NoError(t, err)
require.Equal(t, "bar", string(resp.Cookies()[0].Value()))
resp.Close()
}
func Test_Response_Body(t *testing.T) {
t.Parallel()
setupApp := func() *testServer {
server := startTestServer(t, func(app *fiber.App) {
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("hello world")
})
app.Get("/json", func(c fiber.Ctx) error {
return c.SendString("{\"status\":\"success\"}")
})
app.Get("/xml", func(c fiber.Ctx) error {
return c.SendString("<status><name>success</name></status>")
})
})
return server
}
t.Run("raw body", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com")
require.NoError(t, err)
require.Equal(t, []byte("hello world"), resp.Body())
resp.Close()
})
t.Run("string body", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com")
require.NoError(t, err)
require.Equal(t, "hello world", resp.String())
resp.Close()
})
t.Run("json body", func(t *testing.T) {
t.Parallel()
type body struct {
Status string `json:"status"`
}
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com/json")
require.NoError(t, err)
tmp := &body{}
err = resp.JSON(tmp)
require.NoError(t, err)
require.Equal(t, "success", tmp.Status)
resp.Close()
})
t.Run("xml body", func(t *testing.T) {
t.Parallel()
type body struct {
Name xml.Name `xml:"status"`
Status string `xml:"name"`
}
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com/xml")
require.NoError(t, err)
tmp := &body{}
err = resp.XML(tmp)
require.NoError(t, err)
require.Equal(t, "success", tmp.Status)
resp.Close()
})
}
func Test_Response_Save(t *testing.T) {
t.Parallel()
setupApp := func() *testServer {
server := startTestServer(t, func(app *fiber.App) {
app.Get("/json", func(c fiber.Ctx) error {
return c.SendString("{\"status\":\"success\"}")
})
})
return server
}
t.Run("file path", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com/json")
require.NoError(t, err)
err = resp.Save("./test/tmp.json")
require.NoError(t, err)
defer func() {
_, err := os.Stat("./test/tmp.json")
require.NoError(t, err)
err = os.RemoveAll("./test")
require.NoError(t, err)
}()
file, err := os.Open("./test/tmp.json")
require.NoError(t, err)
defer func(file *os.File) {
err := file.Close()
require.NoError(t, err)
}(file)
data, err := io.ReadAll(file)
require.NoError(t, err)
require.Equal(t, "{\"status\":\"success\"}", string(data))
})
t.Run("io.Writer", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com/json")
require.NoError(t, err)
buf := &bytes.Buffer{}
err = resp.Save(buf)
require.NoError(t, err)
require.Equal(t, "{\"status\":\"success\"}", buf.String())
})
t.Run("error type", func(t *testing.T) {
t.Parallel()
server := setupApp()
defer server.stop()
client := NewClient().SetDial(server.dial())
resp, err := AcquireRequest().
SetClient(client).
Get("http://example.com/json")
require.NoError(t, err)
err = resp.Save(nil)
require.Error(t, err)
})
}

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ import (
"time"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
@ -68,22 +69,30 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) {
Time time.Duration
ExpectedBody string
ExpectedStatusCode int
ExceptedErrsLen int
ExpectedErr error
}{
{Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErrsLen: 0},
{Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: 0, ExceptedErrsLen: 1},
{Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil},
{Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: errors.New("InmemoryListener is already closed: use of closed network connection")},
}
for _, tc := range testCases {
time.Sleep(tc.Time)
a := Get("http://example.com")
a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
code, body, errs := a.String()
req := fasthttp.AcquireRequest()
req.SetRequestURI("http://example.com")
require.Equal(t, tc.ExpectedStatusCode, code)
require.Equal(t, tc.ExpectedBody, body)
require.Len(t, errs, tc.ExceptedErrsLen)
client := fasthttp.HostClient{}
client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
resp := fasthttp.AcquireResponse()
err := client.Do(req, resp)
require.Equal(t, tc.ExpectedErr, err)
require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode())
require.Equal(t, tc.ExpectedBody, string(resp.Body()))
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
}
mu.Lock()

View File

@ -34,6 +34,7 @@ func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) {
if lv == LevelPanic {
panic(buf.String())
}
buf.Reset()
bytebufferpool.Put(buf)
if lv == LevelFatal {
@ -56,6 +57,7 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) {
} else {
_, _ = fmt.Fprint(buf, fmtArgs...)
}
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
if lv == LevelPanic {
panic(buf.String())

View File

@ -2,7 +2,6 @@ package proxy
import (
"bytes"
"crypto/tls"
"net/url"
"strings"
"sync"
@ -105,13 +104,6 @@ var client = &fasthttp.Client{
var lock sync.RWMutex
// WithTLSConfig update http client with a user specified tls.config
// This function should be called before Do and Forward.
// Deprecated: use WithClient instead.
func WithTLSConfig(tlsConfig *tls.Config) {
client.TLSConfig = tlsConfig
}
// WithClient sets the global proxy client.
// This function should be called before Do and Forward.
func WithClient(cli *fasthttp.Client) {

View File

@ -11,8 +11,10 @@ import (
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/tlstest"
clientpkg "github.com/gofiber/fiber/v3/client"
"github.com/stretchr/testify/require"
"github.com/gofiber/fiber/v3/internal/tlstest"
"github.com/valyala/fasthttp"
)
@ -25,8 +27,6 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
require.NoError(t, err)
addr := ln.Addr().String()
go func() {
require.NoError(t, target.Listener(ln, fiber.ListenConfig{
DisableStartupMessage: true,
@ -34,6 +34,7 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str
}()
time.Sleep(2 * time.Second)
addr := ln.Addr().String()
return target, addr
}
@ -104,8 +105,8 @@ func Test_Proxy(t *testing.T) {
require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
}
// go test -run Test_Proxy_Balancer_WithTLSConfig
func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) {
// go test -run Test_Proxy_Balancer_WithTlsConfig
func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
t.Parallel()
serverTLSConf, _, err := tlstest.GetTLSConfigs()
@ -118,7 +119,7 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) {
app := fiber.New()
app.Get("/tlsbalaner", func(c fiber.Ctx) error {
app.Get("/tlsbalancer", func(c fiber.Ctx) error {
return c.SendString("tls balancer")
})
@ -137,15 +138,18 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) {
}))
}()
code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String()
client := clientpkg.NewClient()
client.SetTLSConfig(clientTLSConf)
require.Empty(t, errs)
require.Equal(t, fiber.StatusOK, code)
require.Equal(t, "tls balancer", body)
resp, err := client.Get("https://" + addr + "/tlsbalancer")
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, "tls balancer", string(resp.Body()))
resp.Close()
}
// go test -run Test_Proxy_Forward_WithTLSConfig_To_Http
func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) {
// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
t.Parallel()
_, targetAddr := createProxyTestServer(t, func(c fiber.Ctx) error {
@ -172,14 +176,15 @@ func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) {
}))
}()
code, body, errs := fiber.Get("https://" + proxyAddr).
InsecureSkipVerify().
Timeout(5 * time.Second).
String()
client := clientpkg.NewClient()
client.SetTimeout(5 * time.Second)
client.TLSConfig().InsecureSkipVerify = true
require.Empty(t, errs)
require.Equal(t, fiber.StatusOK, code)
require.Equal(t, "hello from target", body)
resp, err := client.Get("https://" + proxyAddr)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, "hello from target", string(resp.Body()))
resp.Close()
}
// go test -run Test_Proxy_Forward
@ -203,8 +208,8 @@ func Test_Proxy_Forward(t *testing.T) {
require.Equal(t, "forwarded", string(b))
}
// go test -run Test_Proxy_Forward_WithTLSConfig
func Test_Proxy_Forward_WithTLSConfig(t *testing.T) {
// go test -run Test_Proxy_Forward_WithClient_TLSConfig
func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) {
t.Parallel()
serverTLSConf, _, err := tlstest.GetTLSConfigs()
@ -225,7 +230,9 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) {
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification
WithTLSConfig(clientTLSConf)
WithClient(&fasthttp.Client{
TLSConfig: clientTLSConf,
})
app.Use(Forward("https://" + addr + "/tlsfwd"))
go func() {
@ -234,11 +241,14 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) {
}))
}()
code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String()
client := clientpkg.NewClient()
client.SetTLSConfig(clientTLSConf)
require.Empty(t, errs)
require.Equal(t, fiber.StatusOK, code)
require.Equal(t, "tls forward", body)
resp, err := client.Get("https://" + addr)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, "tls forward", string(resp.Body()))
resp.Close()
}
// go test -run Test_Proxy_Modify_Response
@ -415,7 +425,7 @@ func Test_Proxy_Do_WithRedirect(t *testing.T) {
return Do(c, "https://google.com")
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500)
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
require.NoError(t, err1)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
@ -431,7 +441,7 @@ func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
return DoRedirects(c, "http://google.com", 1)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500)
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
require.NoError(t, err1)
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
@ -447,7 +457,7 @@ func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
return DoRedirects(c, "http://google.com", 0)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500)
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
require.NoError(t, err1)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
@ -586,10 +596,13 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) {
}))
}()
code, body, errs := fiber.Get("http://" + addr).String()
require.Empty(t, errs)
require.Equal(t, fiber.StatusOK, code)
require.Equal(t, "test_global_client", body)
client := clientpkg.NewClient()
resp, err := client.Get("http://" + addr)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, "test_global_client", string(resp.Body()))
resp.Close()
}
// go test -race -run Test_Proxy_Forward_Local_Client
@ -615,10 +628,13 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
}))
}()
code, body, errs := fiber.Get("http://" + addr).String()
require.Empty(t, errs)
require.Equal(t, fiber.StatusOK, code)
require.Equal(t, "test_local_client", body)
client := clientpkg.NewClient()
resp, err := client.Get("http://" + addr)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, "test_local_client", string(resp.Body()))
resp.Close()
}
// go test -run Test_ProxyBalancer_Custom_Client
@ -666,7 +682,7 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) {
app1 := fiber.New()
app1.Get("/test", func(c fiber.Ctx) error {
return c.SendString("test_local_client:" + fiber.Query[string](c, "query_test"))
return c.SendString("test_local_client:" + c.Query("query_test"))
})
proxyAddr := ln.Addr().String()
@ -679,13 +695,24 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) {
Dial: fasthttp.Dial,
}))
go func() { require.NoError(t, app.Listener(ln)) }()
go func() { require.NoError(t, app1.Listener(ln1)) }()
go func() {
require.NoError(t, app.Listener(ln, fiber.ListenConfig{
DisableStartupMessage: true,
}))
}()
go func() {
require.NoError(t, app1.Listener(ln1, fiber.ListenConfig{
DisableStartupMessage: true,
}))
}()
code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String()
require.Empty(t, errs)
require.Equal(t, fiber.StatusOK, code)
require.Equal(t, "test_local_client:true", body)
client := clientpkg.NewClient()
resp, err := client.Get("http://" + localDomain + "/test?query_test=true")
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode())
require.Equal(t, "test_local_client:true", string(resp.Body()))
resp.Close()
}
// go test -run Test_Proxy_Balancer_Forward_Local

View File

@ -291,41 +291,45 @@ func Test_Redirect_Request(t *testing.T) {
CookieValue string
ExpectedBody string
ExpectedStatusCode int
ExceptedErrsLen int
ExpectedErr error
}{
{
URL: "/",
CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3",
ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`,
ExpectedStatusCode: StatusOK,
ExceptedErrsLen: 0,
ExpectedErr: nil,
},
{
URL: "/with-inputs?name=john&surname=doe",
CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe",
ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`,
ExpectedStatusCode: StatusOK,
ExceptedErrsLen: 0,
ExpectedErr: nil,
},
{
URL: "/just-inputs?name=john&surname=doe",
CookieValue: "old_input_data_name:john,old_input_data_surname:doe",
ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`,
ExpectedStatusCode: StatusOK,
ExceptedErrsLen: 0,
ExpectedErr: nil,
},
}
for _, tc := range testCases {
a := Get("http://example.com" + tc.URL)
a.Cookie(FlashCookieName, tc.CookieValue)
a.MaxRedirectsCount(1)
a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
code, body, errs := a.String()
client := &fasthttp.HostClient{
Dial: func(_ string) (net.Conn, error) {
return ln.Dial()
},
}
req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse()
req.SetRequestURI("http://example.com" + tc.URL)
req.Header.SetCookie(FlashCookieName, tc.CookieValue)
err := client.DoRedirects(req, resp, 1)
require.Equal(t, tc.ExpectedStatusCode, code)
require.Equal(t, tc.ExpectedBody, body)
require.Len(t, errs, tc.ExceptedErrsLen)
require.NoError(t, err)
require.Equal(t, tc.ExpectedBody, string(resp.Body()))
require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode())
}
}