fiber/client/cookiejar.go
2025-07-17 14:48:08 +02:00

334 lines
8.8 KiB
Go

// The code was originally taken from https://github.com/valyala/fasthttp/pull/526.
package client
import (
"bytes"
"errors"
"net"
"strings"
"sync"
"time"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
// cookieJarPool recycles CookieJar values handed out by AcquireCookieJar and
// given back through ReleaseCookieJar.
var cookieJarPool = sync.Pool{
	New: func() any { return new(CookieJar) },
}
// AcquireCookieJar fetches a CookieJar from the package pool.
//
// It panics if the pool yields a value that is not a *CookieJar, which would
// mean the pool has been corrupted.
func AcquireCookieJar() *CookieJar {
	if jar, ok := cookieJarPool.Get().(*CookieJar); ok {
		return jar
	}
	panic(errors.New("failed to type-assert to *CookieJar"))
}
// ReleaseCookieJar empties c and returns it to the pool. The jar must not be
// used after it has been released.
//
// A nil jar is ignored, so a nil pointer is neither dereferenced nor placed
// in the pool.
func ReleaseCookieJar(c *CookieJar) {
	if c == nil {
		return
	}
	c.Release()
	cookieJarPool.Put(c)
}
// CookieJar manages cookie storage for the client. It stores cookies keyed by host.
type CookieJar struct {
	// hostCookies maps a lower-cased domain (without port) to the cookies
	// stored for it. It is nil until the first cookie is stored and is reset
	// to nil by Release.
	hostCookies map[string][]*fasthttp.Cookie
	// mu is held by the getters and setters while they read or modify hostCookies.
	mu sync.Mutex
}
// Get returns the cookies that apply to uri, judged by its host, path and
// scheme. Cookies marked Secure are only returned for an "https" scheme. The
// result is nil when uri is nil or nothing matches.
//
// The returned cookies are copies owned by the caller, so they may be
// released once they are no longer needed.
func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie {
	if uri == nil {
		return nil
	}

	isHTTPS := string(uri.Scheme()) == "https"
	return cj.getByHostAndPath(uri.Host(), uri.Path(), isHTTPS)
}
// getByHostAndPath returns copies of the cookies that apply to a request for
// the given host and path. host may carry a ":port" suffix, which is dropped
// before matching. secure reports whether the request uses HTTPS and controls
// whether Secure cookies are included.
//
// The jar's map is deliberately not inspected here: it may only be touched
// while holding cj.mu, and cookiesForRequest already copes with an empty jar
// under the lock.
func (cj *CookieJar) getByHostAndPath(host, path []byte, secure bool) []*fasthttp.Cookie {
	hostStr := utils.UnsafeString(host)

	// The port must not take part in domain matching. When host has no port
	// SplitHostPort fails and the host is used as given.
	if h, _, err := net.SplitHostPort(hostStr); err == nil {
		hostStr = h
	}

	return cj.cookiesForRequest(hostStr, path, secure)
}
// getCookiesByHost returns the cookies stored under the exact key host,
// dropping (and releasing) any that have expired.
//
// The returned slice is the jar's own storage, not a copy. It is nil when the
// jar holds no entry for host, including when the jar is empty.
func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie {
	cj.mu.Lock()
	defer cj.mu.Unlock()

	// Looking up a nil map is fine, but writing to one panics. Bail out before
	// the write-back below, which also avoids creating empty entries for
	// hosts that were never stored.
	cookies, ok := cj.hostCookies[host]
	if !ok {
		return nil
	}

	now := time.Now()
	// Filter in place, reusing the backing array of the stored slice.
	kept := cookies[:0]
	for _, c := range cookies {
		// Remove expired cookies.
		if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) {
			fasthttp.ReleaseCookie(c)
			continue
		}
		kept = append(kept, c)
	}
	cj.hostCookies[host] = kept

	return kept
}
// cookiesForRequest collects copies of every stored cookie that applies to a
// request for host and path. Each stored domain that host domain-matches is
// visited; expired cookies found there are released and removed from the jar.
// Cookies marked Secure are skipped unless secure is true.
//
//nolint:revive // secure is required to filter Secure cookies based on scheme
func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) []*fasthttp.Cookie {
	cj.mu.Lock()
	defer cj.mu.Unlock()

	var (
		now     = time.Now()
		matched []*fasthttp.Cookie
	)

	for domain, stored := range cj.hostCookies {
		if !domainMatch(host, domain) {
			continue
		}

		// Compact the stored slice in place while scanning it.
		live := stored[:0]
		for _, ck := range stored {
			if exp := ck.Expire(); !exp.Equal(fasthttp.CookieExpireUnlimited) && exp.Before(now) {
				fasthttp.ReleaseCookie(ck)
				continue
			}
			live = append(live, ck)

			if pathMatch(path, ck.Path()) && (secure || !ck.Secure()) {
				// Hand out a copy so the caller never holds the jar's own cookie.
				dup := fasthttp.AcquireCookie()
				dup.CopyTo(ck)
				matched = append(matched, dup)
			}
		}
		cj.hostCookies[domain] = live
	}

	return matched
}
// Set stores cookies for the host of uri, replacing any stored cookie that
// has the same key. A nil uri makes the call a no-op.
//
// The jar keeps copies, so the caller may release the cookies afterwards.
func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) {
	if uri != nil {
		cj.SetByHost(uri.Host(), cookies...)
	}
}
// SetByHost stores cookies for host, replacing any stored cookie that has the
// same key. host may include a ":port" suffix, which is ignored, and is
// compared case-insensitively.
//
// A cookie that carries a Domain attribute is filed under that domain
// (lower-cased, leading dots removed). Otherwise it is filed under host and
// its Domain is set to host. In both cases the Domain of the passed cookie is
// rewritten to the normalized value.
//
// The jar keeps copies, so the caller may release the cookies afterwards.
func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
	name := utils.UnsafeString(host)
	if stripped, _, err := net.SplitHostPort(name); err == nil {
		name = stripped
	}
	name = utils.ToLower(name)
	// Map keys must not alias the caller's buffer.
	defaultKey := utils.CopyString(name)

	cj.mu.Lock()
	defer cj.mu.Unlock()

	if cj.hostCookies == nil {
		cj.hostCookies = make(map[string][]*fasthttp.Cookie)
	}

	for i := range cookies {
		src := cookies[i]

		attr := utils.TrimLeft(src.Domain(), '.')
		utils.ToLowerBytes(attr)

		bucket := defaultKey
		if len(attr) > 0 {
			bucket = utils.CopyString(utils.UnsafeString(attr))
			src.SetDomainBytes(attr)
		} else {
			src.SetDomain(name)
		}

		stored := cj.hostCookies[bucket]
		slot := searchCookieByKeyAndPath(src.Key(), src.Path(), stored)
		if slot == nil {
			// No cookie to overwrite: add a fresh one to the bucket.
			slot = fasthttp.AcquireCookie()
			stored = append(stored, slot)
		}
		slot.CopyTo(src)
		cj.hostCookies[bucket] = stored
	}
}
// SetKeyValue stores a cookie with the given key and value for host.
//
// The cookie is built in a pooled scratch cookie, so callers do not have to
// allocate a fasthttp.Cookie themselves.
func (cj *CookieJar) SetKeyValue(host, key, value string) {
	c := fasthttp.AcquireCookie()
	// SetByHost copies the cookie into the jar, so the scratch cookie can go
	// back to fasthttp's pool once the call returns.
	defer fasthttp.ReleaseCookie(c)

	c.SetKey(key)
	c.SetValue(value)

	cj.SetByHost(utils.UnsafeBytes(host), c)
}
// SetKeyValueBytes stores a cookie with the given key and value, supplied as
// byte slices, for host.
//
// The cookie is built in a pooled scratch cookie, so callers do not have to
// allocate a fasthttp.Cookie themselves.
func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) {
	c := fasthttp.AcquireCookie()
	// SetByHost copies the cookie into the jar, so the scratch cookie can go
	// back to fasthttp's pool once the call returns.
	defer fasthttp.ReleaseCookie(c)

	c.SetKeyBytes(key)
	c.SetValueBytes(value)

	cj.SetByHost(utils.UnsafeBytes(host), c)
}
// dumpCookiesToReq adds every stored cookie that applies to req's URI to the
// request's Cookie header. Secure cookies are only added for "https" requests.
func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) {
	target := req.URI()
	isHTTPS := string(target.Scheme()) == "https"

	for _, ck := range cj.getByHostAndPath(target.Host(), target.Path(), isHTTPS) {
		req.Header.SetCookieBytesKV(ck.Key(), ck.Value())
		// The jar handed out a copy; it is no longer needed once written.
		fasthttp.ReleaseCookie(ck)
	}
}
// parseCookiesFromResp stores the cookies announced by resp on behalf of the
// host that produced the response. host may include a ":port" suffix, which
// is ignored. The second parameter (the request path) is currently unused.
//
// A cookie without a Domain attribute is filed under host. A cookie with a
// Domain attribute is filed under that domain, but only if host domain-matches
// it; otherwise the cookie is dropped, so a response cannot plant cookies for
// an unrelated domain. A cookie that is already expired removes the stored
// cookie it would have replaced.
func (cj *CookieJar) parseCookiesFromResp(host, _ []byte, resp *fasthttp.Response) {
	hostStr := utils.UnsafeString(host)
	if h, _, err := net.SplitHostPort(hostStr); err == nil {
		hostStr = h
	}
	hostStr = utils.ToLower(hostStr)
	// Map keys must not alias the caller's buffer.
	hostKey := utils.CopyString(hostStr)

	cj.mu.Lock()
	defer cj.mu.Unlock()

	if cj.hostCookies == nil {
		cj.hostCookies = make(map[string][]*fasthttp.Cookie)
	}

	now := time.Now()
	for _, value := range resp.Header.Cookies() {
		tmp := fasthttp.AcquireCookie()
		_ = tmp.ParseBytes(value) //nolint:errcheck // ignore error

		domainBytes := utils.TrimLeft(tmp.Domain(), '.')
		utils.ToLowerBytes(domainBytes)

		key := hostKey
		if len(domainBytes) == 0 {
			tmp.SetDomain(hostStr)
		} else {
			// The response is untrusted input: only accept a Domain attribute
			// that the responding host itself belongs to.
			if !domainMatch(hostStr, utils.UnsafeString(domainBytes)) {
				fasthttp.ReleaseCookie(tmp)
				continue
			}
			key = utils.CopyString(utils.UnsafeString(domainBytes))
			tmp.SetDomainBytes(domainBytes)
		}

		cookies := cj.hostCookies[key]
		c := searchCookieByKeyAndPath(tmp.Key(), tmp.Path(), cookies)
		if c == nil {
			c = fasthttp.AcquireCookie()
			cookies = append(cookies, c)
		}
		c.CopyTo(tmp)

		if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) {
			cj.hostCookies[key] = cookies
		} else {
			// The server sent an expired cookie: drop it from the jar.
			kept := cookies[:0]
			for _, v := range cookies {
				if v != c {
					kept = append(kept, v)
				}
			}
			cj.hostCookies[key] = kept
			fasthttp.ReleaseCookie(c)
		}

		fasthttp.ReleaseCookie(tmp)
	}
}
// Release drops all stored cookies. After this, the CookieJar is empty.
//
// The reset is done while holding cj.mu so that it cannot race with the
// getters and setters, which access the map under the same lock.
func (cj *CookieJar) Release() {
	cj.mu.Lock()
	defer cj.mu.Unlock()

	// FOLLOW-UP performance optimization:
	// The stored cookies are not handed back to fasthttp's pool here. Doing so
	// was found to cause a race condition, because the reset would modify a
	// value that is a reference rather than a copy. A solution would be to
	// make a copy.
	// for _, v := range cj.hostCookies {
	// 	for _, c := range v {
	// 		fasthttp.ReleaseCookie(c)
	// 	}
	// }
	cj.hostCookies = nil
}
// searchCookieByKeyAndPath returns the cookie in cookies whose key and path
// both equal the given ones, or nil if there is none. An empty path is
// treated as "/".
//
// The path comparison is exact rather than a prefix match: the result is the
// cookie that gets overwritten, and a cookie scoped to "/a/b" must not replace
// a cookie of the same name scoped to "/a".
func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie {
	if len(path) == 0 {
		path = []byte("/")
	}

	for _, c := range cookies {
		if !bytes.Equal(key, c.Key()) {
			continue
		}

		cookiePath := c.Path()
		if len(cookiePath) == 0 {
			cookiePath = []byte("/")
		}
		if bytes.Equal(path, cookiePath) {
			return c
		}
	}

	return nil
}
// pathMatch reports whether reqPath path-matches cookiePath as described in
// RFC 6265 section 5.1.4. An empty path on either side is treated as "/".
//
// The paths match when they are identical, or when cookiePath is a prefix of
// reqPath and the prefix ends on a "/" boundary.
func pathMatch(reqPath, cookiePath []byte) bool {
	const root = "/"

	if len(reqPath) == 0 {
		reqPath = []byte(root)
	}
	if len(cookiePath) == 0 {
		cookiePath = []byte(root)
	}

	switch {
	case !bytes.HasPrefix(reqPath, cookiePath):
		return false
	case len(reqPath) == len(cookiePath):
		// Same length and a prefix: the paths are identical.
		return true
	case cookiePath[len(cookiePath)-1] == '/':
		return true
	default:
		// reqPath is strictly longer here; the next byte must start a new segment.
		return reqPath[len(cookiePath)] == '/'
	}
}
// domainMatch reports whether host domain-matches domain: either the
// lower-cased host equals domain, or it ends in "." followed by domain.
// domain is compared as given, so callers pass it already lower-cased.
func domainMatch(host, domain string) bool {
	lowered := utils.ToLower(host)
	return lowered == domain || strings.HasSuffix(lowered, "."+domain)
}