💼 introduce sessions

pull/1009/head
Fenny 2020-11-06 19:32:56 +01:00
parent 0483406a12
commit 8dd663175e
7 changed files with 823 additions and 0 deletions

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/gofiber/fiber/v2
go 1.14
require (
github.com/klauspost/compress v1.11.0 // indirect
github.com/valyala/fasthttp v1.17.0
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1
)

2
go.sum
View File

@ -2,6 +2,8 @@ github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDa
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.11.0 h1:wJbzvpYMVGG9iTI9VxpnNZfd4DzMPoCWze3GgSqz8yg=
github.com/klauspost/compress v1.11.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg=

103
middleware/sessions/db.go Normal file
View File

@ -0,0 +1,103 @@
package sessions
import (
"sync"
)
// go:generate msgp
// msgp -file="db.go" -o="db_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type db struct {
d []kv
}
// go:generate msgp
type kv struct {
k string
v interface{}
}
var dbPool = sync.Pool{
New: func() interface{} {
return new(db)
},
}
func acquireDB() *db {
return dbPool.Get().(*db)
}
func releaseDB(d *db) {
d.Reset()
dbPool.Put(d)
}
func (d *db) Reset() {
d.d = d.d[:0]
}
func (d *db) Get(key string) interface{} {
idx := d.indexOf(key)
if idx > -1 {
return d.d[idx].v
}
return nil
}
func (d *db) Set(key string, value interface{}) {
idx := d.indexOf(key)
if idx > -1 {
kv := &d.d[idx]
kv.v = value
} else {
d.append(key, value)
}
}
func (d *db) Delete(key string) {
idx := d.indexOf(key)
if idx > -1 {
n := len(d.d) - 1
d.swap(idx, n)
d.d = d.d[:n]
}
}
func (d *db) Len() int {
return len(d.d)
}
func (d *db) swap(i, j int) {
iKey, iValue := d.d[i].k, d.d[i].v
jKey, jValue := d.d[j].k, d.d[j].v
d.d[i].k, d.d[i].v = jKey, jValue
d.d[j].k, d.d[j].v = iKey, iValue
}
func (d *db) allocPage() *kv {
n := len(d.d)
if cap(d.d) > n {
d.d = d.d[:n+1]
} else {
d.d = append(d.d, kv{})
}
return &d.d[n]
}
func (d *db) append(key string, value interface{}) {
kv := d.allocPage()
kv.k = key
kv.v = value
}
func (d *db) indexOf(key string) int {
n := len(d.d)
for i := 0; i < n; i++ {
if d.d[i].k == key {
return i
}
}
return -1
}

View File

@ -0,0 +1,365 @@
package sessions
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/gofiber/fiber/v2/internal/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *db) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "d":
var zb0002 uint32
zb0002, err = dc.ReadArrayHeader()
if err != nil {
err = msgp.WrapError(err, "d")
return
}
if cap(z.d) >= int(zb0002) {
z.d = (z.d)[:zb0002]
} else {
z.d = make([]kv, zb0002)
}
for za0001 := range z.d {
var zb0003 uint32
zb0003, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
for zb0003 > 0 {
zb0003--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.d[za0001].k, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "d", za0001, "k")
return
}
case "v":
z.d[za0001].v, err = dc.ReadIntf()
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
}
}
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z *db) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 1
// write "d"
err = en.Append(0x81, 0xa1, 0x64)
if err != nil {
return
}
err = en.WriteArrayHeader(uint32(len(z.d)))
if err != nil {
err = msgp.WrapError(err, "d")
return
}
for za0001 := range z.d {
// map header, size 2
// write "k"
err = en.Append(0x82, 0xa1, 0x6b)
if err != nil {
return
}
err = en.WriteString(z.d[za0001].k)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "k")
return
}
// write "v"
err = en.Append(0xa1, 0x76)
if err != nil {
return
}
err = en.WriteIntf(z.d[za0001].v)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z *db) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 1
// string "d"
o = append(o, 0x81, 0xa1, 0x64)
o = msgp.AppendArrayHeader(o, uint32(len(z.d)))
for za0001 := range z.d {
// map header, size 2
// string "k"
o = append(o, 0x82, 0xa1, 0x6b)
o = msgp.AppendString(o, z.d[za0001].k)
// string "v"
o = append(o, 0xa1, 0x76)
o, err = msgp.AppendIntf(o, z.d[za0001].v)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
}
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *db) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "d":
var zb0002 uint32
zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d")
return
}
if cap(z.d) >= int(zb0002) {
z.d = (z.d)[:zb0002]
} else {
z.d = make([]kv, zb0002)
}
for za0001 := range z.d {
var zb0003 uint32
zb0003, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
for zb0003 > 0 {
zb0003--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.d[za0001].k, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "k")
return
}
case "v":
z.d[za0001].v, bts, err = msgp.ReadIntfBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
}
}
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *db) Msgsize() (s int) {
s = 1 + 2 + msgp.ArrayHeaderSize
for za0001 := range z.d {
s += 1 + 2 + msgp.StringPrefixSize + len(z.d[za0001].k) + 2 + msgp.GuessSize(z.d[za0001].v)
}
return
}
// DecodeMsg implements msgp.Decodable
func (z *kv) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.k, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "k")
return
}
case "v":
z.v, err = dc.ReadIntf()
if err != nil {
err = msgp.WrapError(err, "v")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z kv) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "k"
err = en.Append(0x82, 0xa1, 0x6b)
if err != nil {
return
}
err = en.WriteString(z.k)
if err != nil {
err = msgp.WrapError(err, "k")
return
}
// write "v"
err = en.Append(0xa1, 0x76)
if err != nil {
return
}
err = en.WriteIntf(z.v)
if err != nil {
err = msgp.WrapError(err, "v")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z kv) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "k"
o = append(o, 0x82, 0xa1, 0x6b)
o = msgp.AppendString(o, z.k)
// string "v"
o = append(o, 0xa1, 0x76)
o, err = msgp.AppendIntf(o, z.v)
if err != nil {
err = msgp.WrapError(err, "v")
return
}
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *kv) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.k, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "k")
return
}
case "v":
z.v, bts, err = msgp.ReadIntfBytes(bts)
if err != nil {
err = msgp.WrapError(err, "v")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z kv) Msgsize() (s int) {
s = 1 + 2 + msgp.StringPrefixSize + len(z.k) + 2 + msgp.GuessSize(z.v)
return
}

View File

@ -0,0 +1,113 @@
package sessions
import (
"errors"
"sync"
"time"
)
// copy of https://github.com/gofiber/storage/tree/main/memory
type memory struct {
mux sync.RWMutex
db map[string]memoryEntry
gcInterval time.Duration
done chan struct{}
}
var errNotExist = errors.New("key does not exist")
type memoryEntry struct {
data []byte
expiry int64
}
func memoryStorage() *memory {
// Create storage
store := &memory{
db: make(map[string]memoryEntry),
gcInterval: 10 * time.Second,
done: make(chan struct{}),
}
// Start garbage collector
go store.gc()
return store
}
// Get value by key
func (s *memory) Get(key string) ([]byte, error) {
s.mux.RLock()
v, ok := s.db[key]
s.mux.RUnlock()
if !ok || v.expiry != 0 && v.expiry <= time.Now().Unix() {
return nil, errNotExist
}
return v.data, nil
}
// Set key with value
func (s *memory) Set(key string, val []byte, exp time.Duration) error {
// Ain't Nobody Got Time For That
if len(key) <= 0 || len(val) <= 0 {
return nil
}
var expire int64
if exp != 0 {
expire = time.Now().Add(exp).Unix()
}
s.mux.Lock()
s.db[key] = memoryEntry{val, expire}
s.mux.Unlock()
return nil
}
// Delete key by key
func (s *memory) Delete(key string) error {
// Ain't Nobody Got Time For That
if len(key) <= 0 {
return nil
}
s.mux.Lock()
delete(s.db, key)
s.mux.Unlock()
return nil
}
// Reset all keys
func (s *memory) Reset() error {
s.mux.Lock()
s.db = make(map[string]memoryEntry)
s.mux.Unlock()
return nil
}
// Close the memory storage
func (s *memory) Close() error {
s.done <- struct{}{}
return nil
}
func (s *memory) gc() {
ticker := time.NewTicker(s.gcInterval)
defer ticker.Stop()
for {
select {
case <-s.done:
return
case t := <-ticker.C:
now := t.Unix()
s.mux.Lock()
for id, v := range s.db {
if v.expiry != 0 && v.expiry < now {
delete(s.db, id)
}
}
s.mux.Unlock()
}
}
}

View File

@ -0,0 +1,128 @@
package sessions
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
type Session struct {
ctx *fiber.Ctx
sessions *Sessions
db *db
id string
fresh bool
}
// Fresh is true if the current session is new or existing
func (s *Session) Fresh() bool {
return s.fresh
}
// ID returns the session id
func (s *Session) ID() string {
return s.id
}
// Get will return the value
func (s *Session) Get(key string) interface{} {
return s.db.Get(key)
}
// Set will update or create a new key value
func (s *Session) Set(key string, val interface{}) {
s.db.Set(key, val)
}
// Delete will delete the value
func (s *Session) Delete(key string) {
s.db.Delete(key)
}
// Reset will clear the session and remove from storage
func (s *Session) Reset() error {
s.db.Reset()
return s.sessions.cfg.Store.Delete(s.id)
}
// Save will update the storage and client cookie
func (s *Session) Save() error {
// Expire session if no data is present ( aka reset )
if s.db.Len() <= 0 {
// Delete cookie
s.deleteCookie()
return nil
}
// Convert book to bytes
data, err := s.db.MarshalMsg(nil)
if err != nil {
return err
}
// pass raw bytes with session id to provider
if err := s.sessions.cfg.Store.Set(s.id, data, s.sessions.cfg.Expiration); err != nil {
return err
}
// release db back to pool to be re-used on next request
releaseDB(s.db)
// Create cookie with the session ID
s.setCookie()
return nil
}
func (s *Session) setCookie() {
fcookie := fasthttp.AcquireCookie()
fcookie.SetKey(s.sessions.cfg.Cookie.Name)
fcookie.SetValue(s.id)
fcookie.SetPath(s.sessions.cfg.Cookie.Path)
fcookie.SetDomain(s.sessions.cfg.Cookie.Domain)
fcookie.SetMaxAge(int(s.sessions.cfg.Expiration))
fcookie.SetExpire(time.Now().Add(s.sessions.cfg.Expiration))
fcookie.SetSecure(s.sessions.cfg.Cookie.Secure)
fcookie.SetHTTPOnly(s.sessions.cfg.Cookie.HTTPOnly)
switch utils.ToLower(s.sessions.cfg.Cookie.SameSite) {
case "strict":
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "none":
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
s.ctx.Response().Header.SetCookie(fcookie)
fasthttp.ReleaseCookie(fcookie)
}
func (s *Session) deleteCookie() {
s.ctx.Request().Header.DelCookie(s.sessions.cfg.Cookie.Name)
s.ctx.Response().Header.DelCookie(s.sessions.cfg.Cookie.Name)
fcookie := fasthttp.AcquireCookie()
fcookie.SetKey(s.sessions.cfg.Cookie.Name)
fcookie.SetValue(s.id)
fcookie.SetPath(s.sessions.cfg.Cookie.Path)
fcookie.SetDomain(s.sessions.cfg.Cookie.Domain)
fcookie.SetMaxAge(int(s.sessions.cfg.Expiration))
fcookie.SetExpire(time.Now().Add(-1 * time.Minute))
fcookie.SetSecure(s.sessions.cfg.Cookie.Secure)
fcookie.SetHTTPOnly(s.sessions.cfg.Cookie.HTTPOnly)
switch utils.ToLower(s.sessions.cfg.Cookie.SameSite) {
case "strict":
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "none":
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
s.ctx.Response().Header.SetCookie(fcookie)
fasthttp.ReleaseCookie(fcookie)
}

View File

@ -0,0 +1,111 @@
package sessions
import (
"fmt"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
type Config struct {
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "form:<name>"
// - "cookie:<name>"
//
// Optional. Default value "cookie:_csrf".
// TODO: When to override Cookie.Value?
KeyLookup string
// Optional. Session ID generator function.
//
// Default: utils.UUID
KeyGenerator func() string
// Optional. Cookie to set values on
//
// NOTE: Value, MaxAge and Expires will be overriden by the session ID and expiration
// TODO: Should this be a pointer, if yes why?
Cookie fiber.Cookie
// Allowed session duration
//
// Optional. Default: 24 hours
Expiration time.Duration
// Store interface
// Optional. Default: memory.New
Store fiber.Storage
}
var ConfigDefault = Config{
Cookie: fiber.Cookie{
Value: "session_id",
},
Expiration: 30 * time.Minute,
KeyGenerator: utils.UUID,
}
type Sessions struct {
cfg Config
}
func New(config ...Config) *Sessions {
cfg := ConfigDefault
if len(config) > 0 {
cfg = config[0]
}
if cfg.Store == nil {
cfg.Store = memoryStorage()
}
return &Sessions{
cfg: cfg,
}
}
func (s *Sessions) Get(c *fiber.Ctx) *Session {
var fresh bool
// Get ID from cookie
id := c.Cookies(s.cfg.Cookie.Name)
// If no ID exist, create new one
if len(id) == 0 {
id = s.cfg.KeyGenerator()
fresh = true
}
// Create session object
sess := &Session{
ctx: c,
sessions: s,
fresh: fresh,
db: acquireDB(),
id: id,
}
// Fetch existing data
if !fresh {
raw, err := s.cfg.Store.Get(id)
// Set data
if err == nil {
_, err := sess.db.UnmarshalMsg(raw)
if err != nil {
fmt.Println(err)
}
}
}
// Return session object
return sess
}
func (s *Sessions) Reset() error {
return s.cfg.Store.Reset()
}