mirror of https://github.com/gofiber/fiber.git
wip
parent
f301d39b0f
commit
98bbb40398
|
@ -8,18 +8,17 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/log"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenNotFound = errors.New("csrf token not found")
|
||||
ErrTokenInvalid = errors.New("csrf token invalid")
|
||||
ErrRefererNotFound = errors.New("referer not supplied")
|
||||
ErrRefererInvalid = errors.New("referer invalid")
|
||||
ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin")
|
||||
ErrOriginInvalid = errors.New("origin invalid")
|
||||
ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin")
|
||||
ErrNotGetStorage = errors.New("unable to retrieve data from CSRF storage")
|
||||
ErrTokenNotFound = errors.New("csrf token not found")
|
||||
ErrTokenInvalid = errors.New("csrf token invalid")
|
||||
ErrRefererNotFound = errors.New("referer not supplied")
|
||||
ErrRefererInvalid = errors.New("referer invalid")
|
||||
ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin")
|
||||
ErrOriginInvalid = errors.New("origin invalid")
|
||||
ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin")
|
||||
ErrStorageRetrievalFailed = errors.New("unable to retrieve data from CSRF storage")
|
||||
|
||||
errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
|
||||
dummyValue = []byte{'+'}
|
||||
|
@ -106,8 +105,11 @@ func New(config ...Config) fiber.Handler {
|
|||
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
|
||||
cookieToken := c.Cookies(cfg.CookieName)
|
||||
if cookieToken != "" {
|
||||
// In this case, handling error doesn't make sense because we have validations after the switch.
|
||||
raw, _ := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) //nolint:errcheck //the details are in the comment above
|
||||
raw, err := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager)
|
||||
if err != nil {
|
||||
println("hereee+" + err.Error())
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
if raw != nil {
|
||||
token = cookieToken // Token is valid, safe to set it
|
||||
}
|
||||
|
@ -151,14 +153,17 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
raw, err := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
|
||||
if err != nil || raw == nil {
|
||||
log.Error("Failed to retrieve CSRF token: ", err)
|
||||
if err != nil {
|
||||
|
||||
return cfg.ErrorHandler(c, err)
|
||||
} else if raw == nil {
|
||||
|
||||
// If token is not in storage, expire the cookie
|
||||
expireCSRFCookie(c, cfg)
|
||||
// and return an error
|
||||
return cfg.ErrorHandler(c, ErrTokenNotFound)
|
||||
return cfg.ErrorHandler(c, ErrTokenInvalid)
|
||||
}
|
||||
|
||||
if cfg.SingleUseToken {
|
||||
// If token is single use, delete it from storage
|
||||
deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -1263,7 +1264,6 @@ func Test_CSRF_Cookie_Injection_Exploit(t *testing.T) {
|
|||
ctx.Request.SetRequestURI("/")
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||
|
||||
// Exploit CSRF token we just injected
|
||||
ctx.Request.Reset()
|
||||
|
@ -1509,3 +1509,67 @@ func Test_CSRF_FromContextMethods_Invalid(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
type mockStorage struct{}
|
||||
|
||||
func (m *mockStorage) Get(key string) ([]byte, error) {
|
||||
return nil, fmt.Errorf("not found")
|
||||
}
|
||||
|
||||
func (m *mockStorage) Set(key string, val []byte, exp time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStorage) Delete(key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStorage) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStorage) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_NotGetTokenInSessionStorage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errHandler := func(c fiber.Ctx, err error) error {
|
||||
require.Equal(t, ErrNotGetStorage.Error(), err.Error())
|
||||
return c.Status(419).Send([]byte(err.Error()))
|
||||
}
|
||||
|
||||
// &session.Store{}.Storage.Set(ConfigDefault.CookieName, "fiber", 300)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
ErrorHandler: errHandler,
|
||||
Session: &session.Store{
|
||||
Config: session.Config{
|
||||
Storage: &mockStorage{},
|
||||
KeyGenerator: ConfigDefault.KeyGenerator,
|
||||
KeyLookup: ConfigDefault.KeyLookup,
|
||||
Expiration: ConfigDefault.Expiration,
|
||||
CookieSameSite: "Lax",
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
app.Post("/", func(c fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, "fiber")
|
||||
h(ctx)
|
||||
|
||||
require.Equal(t, 419, ctx.Response.StatusCode())
|
||||
require.Equal(t, "invalid CSRF token", string(ctx.Response.Body()))
|
||||
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
|
@ -29,11 +30,13 @@ func newSessionManager(s *session.Store, k string) *sessionManager {
|
|||
func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) ([]byte, error) {
|
||||
sess, err := m.session.Get(c)
|
||||
if err != nil {
|
||||
log.Warn("csrf: failed to get session: ", err)
|
||||
return nil, ErrTokenNotFound
|
||||
return nil, ErrNotGetStorage
|
||||
}
|
||||
|
||||
fmt.Println("key: ", sess)
|
||||
|
||||
token, ok := sess.Get(m.key).(Token)
|
||||
fmt.Println("key: ", token, ok)
|
||||
if !ok {
|
||||
return nil, ErrTokenInvalid
|
||||
}
|
||||
|
|
|
@ -86,12 +86,26 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) {
|
|||
sess.id = id
|
||||
sess.fresh = fresh
|
||||
|
||||
// Decode session data if found
|
||||
if rawData != nil {
|
||||
sess.data.Lock()
|
||||
defer sess.data.Unlock()
|
||||
if err := sess.decodeSessionData(rawData); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||
// Fetch existing data
|
||||
if loadData {
|
||||
raw, err := s.Storage.Get(id)
|
||||
// Unmarshal if we found data
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, err
|
||||
|
||||
case raw != nil:
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
sess.byteBuffer.Write(raw)
|
||||
encCache := gob.NewDecoder(sess.byteBuffer)
|
||||
err := encCache.Decode(&sess.data.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
default:
|
||||
// both raw and err is nil, which means id is not in the storage
|
||||
sess.fresh = true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue