fiber/middleware/csrf/extractors.go

71 lines
1.8 KiB
Go

package csrf
import (
"errors"
"github.com/gofiber/fiber/v3"
)
// Sentinel errors returned by the extractors in this file when the CSRF token
// is absent (or empty) in the requested location. Callers can compare against
// them with errors.Is.
var (
	// ErrMissingParam is returned by CsrfFromParam when the route parameter is empty.
	ErrMissingParam = errors.New("missing csrf token in param")

	// ErrMissingForm is returned by CsrfFromForm when the form value is empty.
	ErrMissingForm = errors.New("missing csrf token in form")

	// ErrMissingCookie is returned by CsrfFromCookie when the cookie value is empty.
	ErrMissingCookie = errors.New("missing csrf token in cookie")

	// ErrMissingHeader is returned by CsrfFromHeader when the header value is empty.
	ErrMissingHeader = errors.New("missing csrf token in header")

	// ErrMissingQuery is returned by CsrfFromQuery when the query value is empty.
	ErrMissingQuery = errors.New("missing csrf token in query")
)
// CsrfFromParam returns an extractor that reads the CSRF token from the URL
// route parameter named param, using c.Params.
//
// The returned function yields the parameter value and a nil error, or an
// empty string and ErrMissingParam when the value is empty. Only the empty
// string is checked, so a missing parameter and an empty one are treated alike.
func CsrfFromParam(param string) func(c fiber.Ctx) (string, error) {
	return func(c fiber.Ctx) (string, error) {
		token := c.Params(param)
		if token == "" {
			return "", ErrMissingParam
		}
		return token, nil
	}
}
// CsrfFromForm returns an extractor that reads the CSRF token from the form
// field named param, using c.FormValue.
//
// The returned function yields the field value and a nil error, or an empty
// string and ErrMissingForm when the value is empty. Only the empty string is
// checked, so a missing field and an empty one are treated alike.
//
// NOTE(review): an earlier comment described this as "multipart-form" only,
// but c.FormValue in fasthttp-backed contexts is believed to also consult
// url-encoded bodies and the URL query string. Confirm before relying on the
// token coming exclusively from the request body.
func CsrfFromForm(param string) func(c fiber.Ctx) (string, error) {
	return func(c fiber.Ctx) (string, error) {
		token := c.FormValue(param)
		if token == "" {
			return "", ErrMissingForm
		}
		return token, nil
	}
}
// CsrfFromCookie returns an extractor that reads the CSRF token from the
// request cookie named param, using c.Cookies.
//
// The returned function yields the cookie value and a nil error, or an empty
// string and ErrMissingCookie when the value is empty. Only the empty string
// is checked, so a missing cookie and an empty one are treated alike.
//
// NOTE(review): browsers attach cookies automatically, so a token read from a
// cookie may not prove the request was intentionally made by the page. Confirm
// how callers combine this extractor with the stored token.
func CsrfFromCookie(param string) func(c fiber.Ctx) (string, error) {
	return func(c fiber.Ctx) (string, error) {
		token := c.Cookies(param)
		if token == "" {
			return "", ErrMissingCookie
		}
		return token, nil
	}
}
// CsrfFromHeader returns an extractor that reads the CSRF token from the
// request header named param, using c.Get.
//
// The returned function yields the header value and a nil error, or an empty
// string and ErrMissingHeader when the value is empty. Only the empty string
// is checked, so a missing header and an empty one are treated alike.
func CsrfFromHeader(param string) func(c fiber.Ctx) (string, error) {
	return func(c fiber.Ctx) (string, error) {
		token := c.Get(param)
		if token == "" {
			return "", ErrMissingHeader
		}
		return token, nil
	}
}
// CsrfFromQuery returns an extractor that reads the CSRF token from the URL
// query-string parameter named param, using c.Query.
//
// The returned function yields the query value and a nil error, or an empty
// string and ErrMissingQuery when the value is empty. Only the empty string is
// checked, so a missing parameter and an empty one are treated alike.
func CsrfFromQuery(param string) func(c fiber.Ctx) (string, error) {
	return func(c fiber.Ctx) (string, error) {
		token := c.Query(param)
		if token == "" {
			return "", ErrMissingQuery
		}
		return token, nil
	}
}