mirror of https://github.com/gofiber/fiber.git
💼 introduce sessions
parent
0483406a12
commit
8dd663175e
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module github.com/gofiber/fiber/v2
|
|||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/klauspost/compress v1.11.0 // indirect
|
||||
github.com/valyala/fasthttp v1.17.0
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1
|
||||
)
|
||||
|
|
2
go.sum
2
go.sum
|
@ -2,6 +2,8 @@ github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDa
|
|||
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
|
||||
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.11.0 h1:wJbzvpYMVGG9iTI9VxpnNZfd4DzMPoCWze3GgSqz8yg=
|
||||
github.com/klauspost/compress v1.11.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg=
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="db.go" -o="db_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type db struct {
|
||||
d []kv
|
||||
}
|
||||
|
||||
// go:generate msgp
|
||||
type kv struct {
|
||||
k string
|
||||
v interface{}
|
||||
}
|
||||
|
||||
var dbPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(db)
|
||||
},
|
||||
}
|
||||
|
||||
func acquireDB() *db {
|
||||
return dbPool.Get().(*db)
|
||||
}
|
||||
|
||||
func releaseDB(d *db) {
|
||||
d.Reset()
|
||||
dbPool.Put(d)
|
||||
}
|
||||
|
||||
func (d *db) Reset() {
|
||||
d.d = d.d[:0]
|
||||
}
|
||||
|
||||
func (d *db) Get(key string) interface{} {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
return d.d[idx].v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *db) Set(key string, value interface{}) {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
kv := &d.d[idx]
|
||||
kv.v = value
|
||||
} else {
|
||||
d.append(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *db) Delete(key string) {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
n := len(d.d) - 1
|
||||
d.swap(idx, n)
|
||||
d.d = d.d[:n]
|
||||
}
|
||||
}
|
||||
|
||||
func (d *db) Len() int {
|
||||
return len(d.d)
|
||||
}
|
||||
|
||||
func (d *db) swap(i, j int) {
|
||||
iKey, iValue := d.d[i].k, d.d[i].v
|
||||
jKey, jValue := d.d[j].k, d.d[j].v
|
||||
|
||||
d.d[i].k, d.d[i].v = jKey, jValue
|
||||
d.d[j].k, d.d[j].v = iKey, iValue
|
||||
}
|
||||
|
||||
func (d *db) allocPage() *kv {
|
||||
n := len(d.d)
|
||||
if cap(d.d) > n {
|
||||
d.d = d.d[:n+1]
|
||||
} else {
|
||||
d.d = append(d.d, kv{})
|
||||
}
|
||||
return &d.d[n]
|
||||
}
|
||||
|
||||
func (d *db) append(key string, value interface{}) {
|
||||
kv := d.allocPage()
|
||||
kv.k = key
|
||||
kv.v = value
|
||||
}
|
||||
|
||||
func (d *db) indexOf(key string) int {
|
||||
n := len(d.d)
|
||||
for i := 0; i < n; i++ {
|
||||
if d.d[i].k == key {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
|
@ -0,0 +1,365 @@
|
|||
package sessions
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *db) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "d":
|
||||
var zb0002 uint32
|
||||
zb0002, err = dc.ReadArrayHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d")
|
||||
return
|
||||
}
|
||||
if cap(z.d) >= int(zb0002) {
|
||||
z.d = (z.d)[:zb0002]
|
||||
} else {
|
||||
z.d = make([]kv, zb0002)
|
||||
}
|
||||
for za0001 := range z.d {
|
||||
var zb0003 uint32
|
||||
zb0003, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
for zb0003 > 0 {
|
||||
zb0003--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.d[za0001].k, err = dc.ReadString()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.d[za0001].v, err = dc.ReadIntf()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *db) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 1
|
||||
// write "d"
|
||||
err = en.Append(0x81, 0xa1, 0x64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteArrayHeader(uint32(len(z.d)))
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d")
|
||||
return
|
||||
}
|
||||
for za0001 := range z.d {
|
||||
// map header, size 2
|
||||
// write "k"
|
||||
err = en.Append(0x82, 0xa1, 0x6b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteString(z.d[za0001].k)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "k")
|
||||
return
|
||||
}
|
||||
// write "v"
|
||||
err = en.Append(0xa1, 0x76)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteIntf(z.d[za0001].v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *db) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 1
|
||||
// string "d"
|
||||
o = append(o, 0x81, 0xa1, 0x64)
|
||||
o = msgp.AppendArrayHeader(o, uint32(len(z.d)))
|
||||
for za0001 := range z.d {
|
||||
// map header, size 2
|
||||
// string "k"
|
||||
o = append(o, 0x82, 0xa1, 0x6b)
|
||||
o = msgp.AppendString(o, z.d[za0001].k)
|
||||
// string "v"
|
||||
o = append(o, 0xa1, 0x76)
|
||||
o, err = msgp.AppendIntf(o, z.d[za0001].v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *db) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "d":
|
||||
var zb0002 uint32
|
||||
zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d")
|
||||
return
|
||||
}
|
||||
if cap(z.d) >= int(zb0002) {
|
||||
z.d = (z.d)[:zb0002]
|
||||
} else {
|
||||
z.d = make([]kv, zb0002)
|
||||
}
|
||||
for za0001 := range z.d {
|
||||
var zb0003 uint32
|
||||
zb0003, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
for zb0003 > 0 {
|
||||
zb0003--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.d[za0001].k, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.d[za0001].v, bts, err = msgp.ReadIntfBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *db) Msgsize() (s int) {
|
||||
s = 1 + 2 + msgp.ArrayHeaderSize
|
||||
for za0001 := range z.d {
|
||||
s += 1 + 2 + msgp.StringPrefixSize + len(z.d[za0001].k) + 2 + msgp.GuessSize(z.d[za0001].v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *kv) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.k, err = dc.ReadString()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.v, err = dc.ReadIntf()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z kv) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 2
|
||||
// write "k"
|
||||
err = en.Append(0x82, 0xa1, 0x6b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteString(z.k)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "k")
|
||||
return
|
||||
}
|
||||
// write "v"
|
||||
err = en.Append(0xa1, 0x76)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteIntf(z.v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z kv) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 2
|
||||
// string "k"
|
||||
o = append(o, 0x82, 0xa1, 0x6b)
|
||||
o = msgp.AppendString(o, z.k)
|
||||
// string "v"
|
||||
o = append(o, 0xa1, 0x76)
|
||||
o, err = msgp.AppendIntf(o, z.v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *kv) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.k, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.v, bts, err = msgp.ReadIntfBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z kv) Msgsize() (s int) {
|
||||
s = 1 + 2 + msgp.StringPrefixSize + len(z.k) + 2 + msgp.GuessSize(z.v)
|
||||
return
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// copy of https://github.com/gofiber/storage/tree/main/memory
|
||||
type memory struct {
|
||||
mux sync.RWMutex
|
||||
db map[string]memoryEntry
|
||||
gcInterval time.Duration
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
var errNotExist = errors.New("key does not exist")
|
||||
|
||||
type memoryEntry struct {
|
||||
data []byte
|
||||
expiry int64
|
||||
}
|
||||
|
||||
func memoryStorage() *memory {
|
||||
// Create storage
|
||||
store := &memory{
|
||||
db: make(map[string]memoryEntry),
|
||||
gcInterval: 10 * time.Second,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start garbage collector
|
||||
go store.gc()
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// Get value by key
|
||||
func (s *memory) Get(key string) ([]byte, error) {
|
||||
s.mux.RLock()
|
||||
v, ok := s.db[key]
|
||||
s.mux.RUnlock()
|
||||
if !ok || v.expiry != 0 && v.expiry <= time.Now().Unix() {
|
||||
return nil, errNotExist
|
||||
}
|
||||
|
||||
return v.data, nil
|
||||
}
|
||||
|
||||
// Set key with value
|
||||
func (s *memory) Set(key string, val []byte, exp time.Duration) error {
|
||||
// Ain't Nobody Got Time For That
|
||||
if len(key) <= 0 || len(val) <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var expire int64
|
||||
if exp != 0 {
|
||||
expire = time.Now().Add(exp).Unix()
|
||||
}
|
||||
|
||||
s.mux.Lock()
|
||||
s.db[key] = memoryEntry{val, expire}
|
||||
s.mux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete key by key
|
||||
func (s *memory) Delete(key string) error {
|
||||
// Ain't Nobody Got Time For That
|
||||
if len(key) <= 0 {
|
||||
return nil
|
||||
}
|
||||
s.mux.Lock()
|
||||
delete(s.db, key)
|
||||
s.mux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset all keys
|
||||
func (s *memory) Reset() error {
|
||||
s.mux.Lock()
|
||||
s.db = make(map[string]memoryEntry)
|
||||
s.mux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close the memory storage
|
||||
func (s *memory) Close() error {
|
||||
s.done <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memory) gc() {
|
||||
ticker := time.NewTicker(s.gcInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case t := <-ticker.C:
|
||||
now := t.Unix()
|
||||
s.mux.Lock()
|
||||
for id, v := range s.db {
|
||||
if v.expiry != 0 && v.expiry < now {
|
||||
delete(s.db, id)
|
||||
}
|
||||
}
|
||||
s.mux.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ctx *fiber.Ctx
|
||||
sessions *Sessions
|
||||
db *db
|
||||
id string
|
||||
fresh bool
|
||||
}
|
||||
|
||||
// Fresh is true if the current session is new or existing
|
||||
func (s *Session) Fresh() bool {
|
||||
return s.fresh
|
||||
}
|
||||
|
||||
// ID returns the session id
|
||||
func (s *Session) ID() string {
|
||||
return s.id
|
||||
}
|
||||
|
||||
// Get will return the value
|
||||
func (s *Session) Get(key string) interface{} {
|
||||
return s.db.Get(key)
|
||||
}
|
||||
|
||||
// Set will update or create a new key value
|
||||
func (s *Session) Set(key string, val interface{}) {
|
||||
s.db.Set(key, val)
|
||||
}
|
||||
|
||||
// Delete will delete the value
|
||||
func (s *Session) Delete(key string) {
|
||||
s.db.Delete(key)
|
||||
}
|
||||
|
||||
// Reset will clear the session and remove from storage
|
||||
func (s *Session) Reset() error {
|
||||
s.db.Reset()
|
||||
return s.sessions.cfg.Store.Delete(s.id)
|
||||
}
|
||||
|
||||
// Save will update the storage and client cookie
|
||||
func (s *Session) Save() error {
|
||||
// Expire session if no data is present ( aka reset )
|
||||
if s.db.Len() <= 0 {
|
||||
// Delete cookie
|
||||
s.deleteCookie()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert book to bytes
|
||||
data, err := s.db.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// pass raw bytes with session id to provider
|
||||
if err := s.sessions.cfg.Store.Set(s.id, data, s.sessions.cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// release db back to pool to be re-used on next request
|
||||
releaseDB(s.db)
|
||||
|
||||
// Create cookie with the session ID
|
||||
s.setCookie()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) setCookie() {
|
||||
fcookie := fasthttp.AcquireCookie()
|
||||
fcookie.SetKey(s.sessions.cfg.Cookie.Name)
|
||||
fcookie.SetValue(s.id)
|
||||
fcookie.SetPath(s.sessions.cfg.Cookie.Path)
|
||||
fcookie.SetDomain(s.sessions.cfg.Cookie.Domain)
|
||||
fcookie.SetMaxAge(int(s.sessions.cfg.Expiration))
|
||||
fcookie.SetExpire(time.Now().Add(s.sessions.cfg.Expiration))
|
||||
fcookie.SetSecure(s.sessions.cfg.Cookie.Secure)
|
||||
fcookie.SetHTTPOnly(s.sessions.cfg.Cookie.HTTPOnly)
|
||||
|
||||
switch utils.ToLower(s.sessions.cfg.Cookie.SameSite) {
|
||||
case "strict":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||
case "none":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
||||
default:
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
}
|
||||
|
||||
s.ctx.Response().Header.SetCookie(fcookie)
|
||||
fasthttp.ReleaseCookie(fcookie)
|
||||
}
|
||||
|
||||
func (s *Session) deleteCookie() {
|
||||
s.ctx.Request().Header.DelCookie(s.sessions.cfg.Cookie.Name)
|
||||
s.ctx.Response().Header.DelCookie(s.sessions.cfg.Cookie.Name)
|
||||
|
||||
fcookie := fasthttp.AcquireCookie()
|
||||
fcookie.SetKey(s.sessions.cfg.Cookie.Name)
|
||||
fcookie.SetValue(s.id)
|
||||
fcookie.SetPath(s.sessions.cfg.Cookie.Path)
|
||||
fcookie.SetDomain(s.sessions.cfg.Cookie.Domain)
|
||||
fcookie.SetMaxAge(int(s.sessions.cfg.Expiration))
|
||||
fcookie.SetExpire(time.Now().Add(-1 * time.Minute))
|
||||
fcookie.SetSecure(s.sessions.cfg.Cookie.Secure)
|
||||
fcookie.SetHTTPOnly(s.sessions.cfg.Cookie.HTTPOnly)
|
||||
|
||||
switch utils.ToLower(s.sessions.cfg.Cookie.SameSite) {
|
||||
case "strict":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||
case "none":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
||||
default:
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
}
|
||||
|
||||
s.ctx.Response().Header.SetCookie(fcookie)
|
||||
fasthttp.ReleaseCookie(fcookie)
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
// - "param:<name>"
|
||||
// - "form:<name>"
|
||||
// - "cookie:<name>"
|
||||
//
|
||||
// Optional. Default value "cookie:_csrf".
|
||||
// TODO: When to override Cookie.Value?
|
||||
KeyLookup string
|
||||
|
||||
// Optional. Session ID generator function.
|
||||
//
|
||||
// Default: utils.UUID
|
||||
KeyGenerator func() string
|
||||
|
||||
// Optional. Cookie to set values on
|
||||
//
|
||||
// NOTE: Value, MaxAge and Expires will be overriden by the session ID and expiration
|
||||
// TODO: Should this be a pointer, if yes why?
|
||||
Cookie fiber.Cookie
|
||||
|
||||
// Allowed session duration
|
||||
//
|
||||
// Optional. Default: 24 hours
|
||||
Expiration time.Duration
|
||||
|
||||
// Store interface
|
||||
// Optional. Default: memory.New
|
||||
Store fiber.Storage
|
||||
}
|
||||
|
||||
var ConfigDefault = Config{
|
||||
Cookie: fiber.Cookie{
|
||||
Value: "session_id",
|
||||
},
|
||||
Expiration: 30 * time.Minute,
|
||||
KeyGenerator: utils.UUID,
|
||||
}
|
||||
|
||||
type Sessions struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func New(config ...Config) *Sessions {
|
||||
cfg := ConfigDefault
|
||||
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
|
||||
if cfg.Store == nil {
|
||||
cfg.Store = memoryStorage()
|
||||
}
|
||||
|
||||
return &Sessions{
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sessions) Get(c *fiber.Ctx) *Session {
|
||||
var fresh bool
|
||||
|
||||
// Get ID from cookie
|
||||
id := c.Cookies(s.cfg.Cookie.Name)
|
||||
|
||||
// If no ID exist, create new one
|
||||
if len(id) == 0 {
|
||||
id = s.cfg.KeyGenerator()
|
||||
fresh = true
|
||||
}
|
||||
|
||||
// Create session object
|
||||
sess := &Session{
|
||||
ctx: c,
|
||||
sessions: s,
|
||||
fresh: fresh,
|
||||
db: acquireDB(),
|
||||
id: id,
|
||||
}
|
||||
|
||||
// Fetch existing data
|
||||
if !fresh {
|
||||
raw, err := s.cfg.Store.Get(id)
|
||||
|
||||
// Set data
|
||||
if err == nil {
|
||||
_, err := sess.db.UnmarshalMsg(raw)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Return session object
|
||||
return sess
|
||||
}
|
||||
|
||||
func (s *Sessions) Reset() error {
|
||||
return s.cfg.Store.Reset()
|
||||
}
|
Loading…
Reference in New Issue