fiber/client/cookiejar.go

246 lines
6.5 KiB
Go

// The code has been taken from https://github.com/valyala/fasthttp/pull/526 originally.
package client
import (
"bytes"
"errors"
"net"
"sync"
"time"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
// cookieJarPool recycles CookieJar values handed out by AcquireCookieJar and
// given back through ReleaseCookieJar. Fresh jars are zero-valued.
var cookieJarPool = sync.Pool{
	New: func() any {
		return new(CookieJar)
	},
}
// AcquireCookieJar returns an empty CookieJar object from pool.
//
// Give the jar back with ReleaseCookieJar once it is no longer needed. It
// panics if the pool yields something other than a *CookieJar, which would
// indicate a programming error.
func AcquireCookieJar() *CookieJar {
	if jar, ok := cookieJarPool.Get().(*CookieJar); ok {
		return jar
	}
	panic(errors.New("failed to type-assert to *CookieJar"))
}
// ReleaseCookieJar returns CookieJar to the pool.
//
// The jar's stored cookies are dropped first (via Release), so c must not be
// used after this call.
func ReleaseCookieJar(c *CookieJar) {
	// Empty the jar before pooling it so the next acquirer starts clean.
	c.Release()
	cookieJarPool.Put(c)
}
// CookieJar manages cookie storage. It is used by the client to store cookies.
//
// The zero value is usable: the host map is allocated lazily the first time
// a cookie is stored.
type CookieJar struct {
	// hostCookies maps a host name to the cookies stored for that host.
	hostCookies map[string][]*fasthttp.Cookie
	// mu guards hostCookies.
	mu sync.Mutex
}
// Get returns the cookies stored for uri's host whose path matches uri's path.
// If no stored cookie matches (or uri is nil), the returned slice is nil.
//
// The returned cookies are copies acquired from the fasthttp cookie pool, so
// the caller owns them: they may be modified or released with
// fasthttp.ReleaseCookie without affecting the cookies kept by the jar.
func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie {
	if uri == nil {
		return nil
	}

	stored := cj.getByHostAndPath(uri.Host(), uri.Path())
	if len(stored) == 0 {
		return nil
	}

	// Copy while holding the lock so a concurrent Set/SetByHost, which
	// overwrites stored cookies in place, cannot modify a cookie mid-copy.
	cj.mu.Lock()
	defer cj.mu.Unlock()

	copies := make([]*fasthttp.Cookie, len(stored))
	for i, c := range stored {
		dst := fasthttp.AcquireCookie()
		dst.CopyTo(c)
		copies[i] = dst
	}
	return copies
}
// getByHostAndPath returns the non-expired cookies stored for host whose Path
// attribute path-matches the request path, following RFC 6265 section 5.1.4.
// A port in host is ignored. nil is returned when no cookie matches.
//
// The returned slice is freshly allocated, but the cookies it points to are
// the jar's own: callers must not release or modify them.
func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie {
	hostStr := utils.UnsafeString(host)
	// port must not be included.
	if h, _, err := net.SplitHostPort(hostStr); err == nil {
		hostStr = h
	}

	// getCookiesByHost reads the map under cj.mu (and deletes expired
	// cookies), so the map is never accessed unsynchronized here.
	cookies := cj.getCookiesByHost(hostStr)
	if len(cookies) == 0 {
		return nil
	}

	matched := make([]*fasthttp.Cookie, 0, len(cookies))
	for _, c := range cookies {
		if cookiePathMatches(path, c.Path()) {
			matched = append(matched, c)
		}
	}
	if len(matched) == 0 {
		return nil
	}
	return matched
}

// cookiePathMatches reports whether requestPath path-matches cookiePath
// (RFC 6265, section 5.1.4). A cookie path of at most one byte (empty or "/")
// matches every request path.
func cookiePathMatches(requestPath, cookiePath []byte) bool {
	if len(cookiePath) <= 1 {
		return true
	}
	// The cookie path must be a prefix of the request path, not vice versa.
	if !bytes.HasPrefix(requestPath, cookiePath) {
		return false
	}
	if len(requestPath) == len(cookiePath) || cookiePath[len(cookiePath)-1] == '/' {
		return true
	}
	// "/api" matches "/api/users" but not "/apiary".
	return requestPath[len(cookiePath)] == '/'
}
// getCookiesByHost returns the cookies stored for host.
// If cookies are expired they are removed from the jar and released to the
// fasthttp cookie pool.
//
// The shortened slice is written back to the map, so a released cookie can
// never be handed out again. The returned slice is a private copy of the
// stored pointers, so it can be iterated after the lock is dropped while other
// goroutines modify the jar. The cookies themselves still belong to the jar.
func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie {
	cj.mu.Lock()
	defer cj.mu.Unlock()

	cookies := cj.hostCookies[host]
	if len(cookies) == 0 {
		return nil
	}

	now := time.Now()
	// Filter in place; the map entry is updated below whenever it shrinks.
	kept := cookies[:0]
	for _, c := range cookies {
		if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) {
			fasthttp.ReleaseCookie(c)
			continue
		}
		kept = append(kept, c)
	}

	if len(kept) != len(cookies) {
		// Drop references to released cookies left in the unused tail.
		for i := len(kept); i < len(cookies); i++ {
			cookies[i] = nil
		}
		if len(kept) == 0 {
			delete(cj.hostCookies, host)
			return nil
		}
		// host may alias caller-owned memory (it comes from
		// utils.UnsafeString), and assigning to an existing string key also
		// replaces the stored key, so store the entry under a private copy.
		cj.hostCookies[string(utils.UnsafeBytes(host))] = kept
	}

	return append(make([]*fasthttp.Cookie, 0, len(kept)), kept...)
}
// Set sets cookies for the host taken from uri.Host(); a nil uri is ignored.
// A stored cookie found for the same key (see SetByHost) is overwritten with
// the new cookie's properties.
//
// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely.
func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) {
	if uri != nil {
		cj.SetByHost(uri.Host(), cookies...)
	}
}
// SetByHost sets cookies for a specific host.
// If the cookie key already exists (matched by searchCookieByKeyAndPath using
// the cookie's own path) it is replaced by the new cookie value.
//
// CookieJar keeps a copy of the cookies, so the parsed cookies can be released
// safely. host is never retained: the jar stores its own copy of the key.
func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
	cj.mu.Lock()
	defer cj.mu.Unlock()

	if cj.hostCookies == nil {
		cj.hostCookies = make(map[string][]*fasthttp.Cookie)
	}

	// A lookup with an aliasing key is safe; only storing one is not.
	stored := cj.hostCookies[utils.UnsafeString(host)]
	before := len(stored)

	for _, cookie := range cookies {
		c := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), stored)
		if c == nil {
			// If the cookie does not exist in the slice, acquire a new cookie and store it.
			c = fasthttp.AcquireCookie()
			stored = append(stored, c)
		}
		c.CopyTo(cookie) // override cookie properties
	}

	// Overwriting an existing cookie mutates it in place, so the map only
	// needs updating when cookies were appended. Assigning to an existing
	// string key also replaces the stored key, so always use a private copy
	// of host rather than a view of caller-owned bytes.
	if len(stored) != before {
		cj.hostCookies[string(host)] = stored
	}
}
// SetKeyValue sets a cookie by key and value for a specific host.
//
// This function prevents extra allocations by making repeated cookies
// not being duplicated. The temporary cookie used to carry key and value is
// returned to the fasthttp pool afterwards, since SetByHost stores a copy.
func (cj *CookieJar) SetKeyValue(host, key, value string) {
	c := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(c)

	c.SetKey(key)
	c.SetValue(value)
	cj.SetByHost(utils.UnsafeBytes(host), c)
}
// SetKeyValueBytes sets a cookie by key and value for a specific host.
//
// This function prevents extra allocations by making repeated cookies
// not being duplicated. The temporary cookie used to carry key and value is
// returned to the fasthttp pool afterwards, since SetByHost stores a copy.
func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) {
	c := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(c)

	c.SetKeyBytes(key)
	c.SetValueBytes(value)
	cj.SetByHost(utils.UnsafeBytes(host), c)
}
// dumpCookiesToReq adds to req's Cookie header every stored cookie that
// getByHostAndPath selects for req's host and path (key and value only).
func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) {
	u := req.URI()
	for _, c := range cj.getByHostAndPath(u.Host(), u.Path()) {
		req.Header.SetCookieBytesKV(c.Key(), c.Value())
	}
}
// parseCookiesFromResp parses the response cookies and stores them for host.
//
// Each Set-Cookie header is parsed into a fresh cookie first, so a malformed
// header (ParseBytes error) is skipped instead of clobbering a stored cookie.
// When searchCookieByKeyAndPath finds a stored cookie for the key, that cookie
// is updated in place rather than appended again. Otherwise the new cookie is
// appended unless it is already expired. Finally, every expired cookie of
// host, including ones the server just expired, is removed and released.
func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) {
	cj.mu.Lock()
	defer cj.mu.Unlock()

	if cj.hostCookies == nil {
		cj.hostCookies = make(map[string][]*fasthttp.Cookie)
	}

	// A lookup with an aliasing key is safe; the key is copied before storing.
	cookies := cj.hostCookies[utils.UnsafeString(host)]

	now := time.Now()
	alive := func(c *fasthttp.Cookie) bool {
		return c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now)
	}
	changed := false

	resp.Header.VisitAllCookie(func(key, value []byte) {
		parsed := fasthttp.AcquireCookie()
		if err := parsed.ParseBytes(value); err != nil {
			fasthttp.ReleaseCookie(parsed)
			return
		}
		if stored := searchCookieByKeyAndPath(key, path, cookies); stored != nil {
			// Update the existing cookie; appending it again would duplicate it.
			stored.CopyTo(parsed)
			fasthttp.ReleaseCookie(parsed)
			return
		}
		if !alive(parsed) {
			fasthttp.ReleaseCookie(parsed)
			return
		}
		cookies = append(cookies, parsed)
		changed = true
	})

	// Drop expired cookies in place and return them to the pool.
	kept := cookies[:0]
	for _, c := range cookies {
		if alive(c) {
			kept = append(kept, c)
			continue
		}
		fasthttp.ReleaseCookie(c)
		changed = true
	}
	for i := len(kept); i < len(cookies); i++ {
		cookies[i] = nil
	}

	if !changed {
		return
	}
	if len(kept) == 0 {
		delete(cj.hostCookies, utils.UnsafeString(host))
		return
	}
	// Assigning to an existing string key also replaces the stored key, so
	// use a private copy of host rather than a view of the caller's bytes.
	cj.hostCookies[string(host)] = kept
}
// Release removes all stored cookies, leaving the jar empty. ReleaseCookieJar
// calls it before putting the jar back into the pool.
//
// The lock is taken because hostCookies is otherwise only accessed under cj.mu.
func (cj *CookieJar) Release() {
	cj.mu.Lock()
	defer cj.mu.Unlock()

	// FOLLOW-UP performance optimization: the stored cookies could be returned
	// to the fasthttp pool here. Currently a race condition is found because
	// the reset method modifies a value which is not a copy but a reference
	// (getByHostAndPath hands out the stored cookies themselves) -> the
	// solution should be to make a copy.
	//
	// for _, v := range cj.hostCookies {
	// 	for _, c := range v {
	// 		fasthttp.ReleaseCookie(c)
	// 	}
	// }
	cj.hostCookies = nil
}
// searchCookieByKeyAndPath returns the first cookie in cookies whose key equals
// key. When path is longer than one byte, the cookie's Path must also begin
// with path; a shorter path (empty or "/") accepts any cookie path.
// It returns nil when no cookie qualifies.
func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie {
	anyPath := len(path) <= 1
	for i := range cookies {
		candidate := cookies[i]
		if !bytes.Equal(key, candidate.Key()) {
			continue
		}
		if anyPath || bytes.HasPrefix(candidate.Path(), path) {
			return candidate
		}
	}
	return nil
}