mirror of
https://github.com/gofiber/fiber.git
synced 2025-07-27 04:30:18 +00:00
334 lines
8.8 KiB
Go
334 lines
8.8 KiB
Go
// The code was originally taken from https://github.com/valyala/fasthttp/pull/526.
|
|
package client
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gofiber/utils/v2"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
var cookieJarPool = sync.Pool{
|
|
New: func() any {
|
|
return &CookieJar{}
|
|
},
|
|
}
|
|
|
|
// AcquireCookieJar returns an empty CookieJar object from the pool.
|
|
func AcquireCookieJar() *CookieJar {
|
|
jar, ok := cookieJarPool.Get().(*CookieJar)
|
|
if !ok {
|
|
panic(errors.New("failed to type-assert to *CookieJar"))
|
|
}
|
|
|
|
return jar
|
|
}
|
|
|
|
// ReleaseCookieJar returns a CookieJar object to the pool.
|
|
func ReleaseCookieJar(c *CookieJar) {
|
|
c.Release()
|
|
cookieJarPool.Put(c)
|
|
}
|
|
|
|
// CookieJar manages cookie storage for the client. It stores cookies keyed by host.
|
|
type CookieJar struct {
|
|
hostCookies map[string][]*fasthttp.Cookie
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// Get returns all cookies stored for a given URI. If there are no cookies for the
|
|
// provided host, the returned slice will be nil.
|
|
//
|
|
// The CookieJar keeps its own copies of cookies, so it is safe to release the returned
|
|
// cookies after use.
|
|
func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie {
|
|
if uri == nil {
|
|
return nil
|
|
}
|
|
|
|
secure := bytes.Equal(uri.Scheme(), []byte("https"))
|
|
return cj.getByHostAndPath(uri.Host(), uri.Path(), secure)
|
|
}
|
|
|
|
// getByHostAndPath returns cookies stored for a specific host and path.
|
|
func (cj *CookieJar) getByHostAndPath(host, path []byte, secure bool) []*fasthttp.Cookie {
|
|
if cj.hostCookies == nil {
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
err error
|
|
hostStr = utils.UnsafeString(host)
|
|
)
|
|
|
|
// port must not be included.
|
|
hostStr, _, err = net.SplitHostPort(hostStr)
|
|
if err != nil {
|
|
hostStr = utils.UnsafeString(host)
|
|
}
|
|
return cj.cookiesForRequest(hostStr, path, secure)
|
|
}
|
|
|
|
// getCookiesByHost returns cookies stored for a specific host, removing any that have expired.
|
|
func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie {
|
|
cj.mu.Lock()
|
|
defer cj.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
cookies := cj.hostCookies[host]
|
|
|
|
kept := cookies[:0]
|
|
for _, c := range cookies {
|
|
// Remove expired cookies.
|
|
if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) {
|
|
fasthttp.ReleaseCookie(c)
|
|
continue
|
|
}
|
|
kept = append(kept, c)
|
|
}
|
|
cj.hostCookies[host] = kept
|
|
|
|
return kept
|
|
}
|
|
|
|
// cookiesForRequest returns cookies that match the given host, path and security settings.
|
|
//
|
|
//nolint:revive // secure is required to filter Secure cookies based on scheme
|
|
func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) []*fasthttp.Cookie {
|
|
cj.mu.Lock()
|
|
defer cj.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
var matched []*fasthttp.Cookie
|
|
|
|
for domain, cookies := range cj.hostCookies {
|
|
if !domainMatch(host, domain) {
|
|
continue
|
|
}
|
|
|
|
kept := cookies[:0]
|
|
for _, c := range cookies {
|
|
if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) {
|
|
fasthttp.ReleaseCookie(c)
|
|
continue
|
|
}
|
|
kept = append(kept, c)
|
|
|
|
if !pathMatch(path, c.Path()) {
|
|
continue
|
|
}
|
|
if c.Secure() && !secure {
|
|
continue
|
|
}
|
|
nc := fasthttp.AcquireCookie()
|
|
nc.CopyTo(c)
|
|
matched = append(matched, nc)
|
|
}
|
|
cj.hostCookies[domain] = kept
|
|
}
|
|
|
|
return matched
|
|
}
|
|
|
|
// Set stores the given cookies for the specified URI host. If a cookie key already exists,
|
|
// it will be replaced by the new cookie value.
|
|
//
|
|
// CookieJar stores copies of the provided cookies, so they may be safely released after use.
|
|
func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) {
|
|
if uri == nil {
|
|
return
|
|
}
|
|
cj.SetByHost(uri.Host(), cookies...)
|
|
}
|
|
|
|
// SetByHost stores the given cookies for the specified host. If a cookie key already exists,
|
|
// it will be replaced by the new cookie value.
|
|
//
|
|
// CookieJar stores copies of the provided cookies, so they may be safely released after use.
|
|
func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
|
|
hostStr := utils.UnsafeString(host)
|
|
if h, _, err := net.SplitHostPort(hostStr); err == nil {
|
|
hostStr = h
|
|
}
|
|
hostStr = utils.ToLower(hostStr)
|
|
hostKey := utils.CopyString(hostStr)
|
|
|
|
cj.mu.Lock()
|
|
defer cj.mu.Unlock()
|
|
|
|
if cj.hostCookies == nil {
|
|
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
|
|
}
|
|
|
|
for _, cookie := range cookies {
|
|
domain := utils.TrimLeft(cookie.Domain(), '.')
|
|
utils.ToLowerBytes(domain)
|
|
key := hostKey
|
|
if len(domain) == 0 {
|
|
cookie.SetDomain(hostStr)
|
|
} else {
|
|
key = utils.CopyString(utils.UnsafeString(domain))
|
|
cookie.SetDomainBytes(domain)
|
|
}
|
|
|
|
hostCookies := cj.hostCookies[key]
|
|
|
|
existing := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies)
|
|
if existing == nil {
|
|
existing = fasthttp.AcquireCookie()
|
|
hostCookies = append(hostCookies, existing)
|
|
}
|
|
existing.CopyTo(cookie)
|
|
cj.hostCookies[key] = hostCookies
|
|
}
|
|
}
|
|
|
|
// SetKeyValue sets a cookie for the specified host with the given key and value.
|
|
//
|
|
// This function helps prevent extra allocations by avoiding duplication of repeated cookies.
|
|
func (cj *CookieJar) SetKeyValue(host, key, value string) {
|
|
c := fasthttp.AcquireCookie()
|
|
c.SetKey(key)
|
|
c.SetValue(value)
|
|
|
|
cj.SetByHost(utils.UnsafeBytes(host), c)
|
|
}
|
|
|
|
// SetKeyValueBytes sets a cookie for the specified host using byte slices for the key and value.
|
|
//
|
|
// This function helps prevent extra allocations by avoiding duplication of repeated cookies.
|
|
func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) {
|
|
c := fasthttp.AcquireCookie()
|
|
c.SetKeyBytes(key)
|
|
c.SetValueBytes(value)
|
|
|
|
cj.SetByHost(utils.UnsafeBytes(host), c)
|
|
}
|
|
|
|
// dumpCookiesToReq writes the stored cookies to the given request.
|
|
func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) {
|
|
uri := req.URI()
|
|
secure := bytes.Equal(uri.Scheme(), []byte("https"))
|
|
cookies := cj.getByHostAndPath(uri.Host(), uri.Path(), secure)
|
|
for _, cookie := range cookies {
|
|
req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value())
|
|
fasthttp.ReleaseCookie(cookie)
|
|
}
|
|
}
|
|
|
|
// parseCookiesFromResp parses the cookies from the response and stores them for the specified host and path.
|
|
func (cj *CookieJar) parseCookiesFromResp(host, _ []byte, resp *fasthttp.Response) {
|
|
hostStr := utils.UnsafeString(host)
|
|
if h, _, err := net.SplitHostPort(hostStr); err == nil {
|
|
hostStr = h
|
|
}
|
|
hostStr = utils.ToLower(hostStr)
|
|
hostKey := utils.CopyString(hostStr)
|
|
|
|
cj.mu.Lock()
|
|
defer cj.mu.Unlock()
|
|
|
|
if cj.hostCookies == nil {
|
|
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
|
|
}
|
|
|
|
now := time.Now()
|
|
for _, value := range resp.Header.Cookies() {
|
|
tmp := fasthttp.AcquireCookie()
|
|
_ = tmp.ParseBytes(value) //nolint:errcheck // ignore error
|
|
|
|
domainBytes := utils.TrimLeft(tmp.Domain(), '.')
|
|
utils.ToLowerBytes(domainBytes)
|
|
key := hostKey
|
|
if len(domainBytes) == 0 {
|
|
tmp.SetDomain(hostStr)
|
|
} else {
|
|
key = utils.CopyString(utils.UnsafeString(domainBytes))
|
|
tmp.SetDomainBytes(domainBytes)
|
|
}
|
|
|
|
cookies := cj.hostCookies[key]
|
|
c := searchCookieByKeyAndPath(tmp.Key(), tmp.Path(), cookies)
|
|
if c == nil {
|
|
c = fasthttp.AcquireCookie()
|
|
cookies = append(cookies, c)
|
|
}
|
|
|
|
c.CopyTo(tmp)
|
|
if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) {
|
|
cj.hostCookies[key] = cookies
|
|
} else {
|
|
kept := cookies[:0]
|
|
for _, v := range cookies {
|
|
if v != c {
|
|
kept = append(kept, v)
|
|
}
|
|
}
|
|
cj.hostCookies[key] = kept
|
|
fasthttp.ReleaseCookie(c)
|
|
}
|
|
fasthttp.ReleaseCookie(tmp)
|
|
}
|
|
}
|
|
|
|
// Release releases all stored cookies. After this, the CookieJar is empty.
|
|
func (cj *CookieJar) Release() {
|
|
// FOLLOW-UP performance optimization:
|
|
// Currently, a race condition is found because the reset method modifies a value
|
|
// that is not a copy but a reference. A solution would be to make a copy.
|
|
// for _, v := range cj.hostCookies {
|
|
// for _, c := range v {
|
|
// fasthttp.ReleaseCookie(c)
|
|
// }
|
|
// }
|
|
cj.hostCookies = nil
|
|
}
|
|
|
|
// searchCookieByKeyAndPath looks up a cookie by its key and path from the provided slice of cookies.
|
|
func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie {
|
|
for _, c := range cookies {
|
|
if bytes.Equal(key, c.Key()) {
|
|
if pathMatch(path, c.Path()) {
|
|
return c
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// pathMatch determines whether the request path matches the cookie path
|
|
// according to RFC 6265 section 5.1.4.
|
|
func pathMatch(reqPath, cookiePath []byte) bool {
|
|
if len(reqPath) == 0 {
|
|
reqPath = []byte("/")
|
|
}
|
|
if len(cookiePath) == 0 {
|
|
cookiePath = []byte("/")
|
|
}
|
|
if bytes.Equal(reqPath, cookiePath) {
|
|
return true
|
|
}
|
|
if !bytes.HasPrefix(reqPath, cookiePath) {
|
|
return false
|
|
}
|
|
if cookiePath[len(cookiePath)-1] == '/' {
|
|
return true
|
|
}
|
|
return len(reqPath) > len(cookiePath) && reqPath[len(cookiePath)] == '/'
|
|
}
|
|
|
|
// domainMatch reports whether host domain-matches the given cookie domain.
|
|
func domainMatch(host, domain string) bool {
|
|
host = utils.ToLower(host)
|
|
|
|
if host == domain {
|
|
return true
|
|
}
|
|
return strings.HasSuffix(host, "."+domain)
|
|
}
|