175 lines
4.6 KiB
Go

package basicauth
import (
"crypto/sha256"
"crypto/sha512"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"github.com/gofiber/fiber/v3"
"golang.org/x/crypto/bcrypt"
)
// Config defines the configuration for the basicauth middleware.
type Config struct {
	// Next defines a function to skip this middleware when it returns true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// Users maps each allowed username to its stored password HASH (not a
	// plaintext password). It is only consulted when Authorizer is nil; in
	// that case every value must be in one of these formats, and a value that
	// cannot be decoded causes a panic while the configuration is built:
	//   - a bcrypt hash (value starts with "$2");
	//   - "{SHA512}" followed by the standard-base64 SHA-512 digest;
	//   - "{SHA256}" followed by the standard-base64 SHA-256 digest;
	//   - a bare SHA-256 digest, encoded as hex or standard base64.
	//
	// Required unless Authorizer is set. Default: map[string]string{}
	Users map[string]string

	// Authorizer defines a function you can pass to check the credentials
	// however you want. It is called with the username, the password and the
	// current fiber context, and returns true when the credentials are
	// approved. When nil, credentials are checked against Users.
	//
	// Optional. Default: nil.
	Authorizer func(string, string, fiber.Ctx) bool

	// Unauthorized defines the handler that produces the response for
	// unauthorized requests. When nil, the default handler replies with
	// 401 Unauthorized and sets the WWW-Authenticate header (built from Realm
	// and Charset), "Cache-Control: no-store" and "Vary: Authorization".
	//
	// Optional. Default: nil
	Unauthorized fiber.Handler

	// Realm defines the realm attribute of the Basic challenge. The realm
	// identifies the system to authenticate against and can be used by
	// clients to save credentials.
	//
	// Optional. Default: "Restricted".
	Realm string

	// Charset defines the value for the charset parameter in the
	// WWW-Authenticate header. According to RFC 7617, clients can use
	// this value to interpret credentials correctly.
	//
	// Optional. Default: "UTF-8".
	Charset string

	// HeaderLimit specifies the maximum allowed length of the
	// Authorization header. Requests exceeding this limit will
	// be rejected. Values <= 0 fall back to the default.
	//
	// Optional. Default: 8192.
	HeaderLimit int
}
// ConfigDefault holds the baseline middleware settings. Fields a caller
// leaves at their zero value are filled in from here: an empty Users map,
// the "Restricted" realm, the "UTF-8" charset and a header limit of 8192.
// The function-valued fields are deliberately left nil.
var ConfigDefault = Config{
	Next:         nil,
	Users:        map[string]string{},
	Authorizer:   nil,
	Unauthorized: nil,
	Realm:        "Restricted",
	Charset:      "UTF-8",
	HeaderLimit:  8192,
}
// configDefault returns the effective middleware configuration.
//
// With no argument it starts from ConfigDefault; otherwise it uses config[0]
// (any further arguments are ignored). Zero-valued fields are then filled
// from ConfigDefault, and Authorizer / Unauthorized receive built-in
// implementations when they are nil.
//
// It panics if Authorizer is nil and an entry in Users holds a password hash
// that parseHashedPassword cannot decode; the panic value is an error that
// names the offending user and wraps the decode error.
func configDefault(config ...Config) Config {
	// Do not return ConfigDefault directly: its Authorizer and Unauthorized
	// are nil, and the defaulting below is what installs usable handlers.
	cfg := ConfigDefault
	if len(config) > 0 {
		cfg = config[0]
	}

	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	if cfg.Users == nil {
		cfg.Users = ConfigDefault.Users
	}
	if cfg.Realm == "" {
		cfg.Realm = ConfigDefault.Realm
	}
	if cfg.Charset == "" {
		cfg.Charset = ConfigDefault.Charset
	}
	if cfg.HeaderLimit <= 0 {
		cfg.HeaderLimit = ConfigDefault.HeaderLimit
	}

	if cfg.Authorizer == nil {
		// Decode every stored hash once, while the configuration is built,
		// so a malformed entry fails here instead of on a matching request.
		verifiers := make(map[string]func(string) bool, len(cfg.Users))
		for user, hashed := range cfg.Users {
			verify, err := parseHashedPassword(hashed)
			if err != nil {
				panic(fmt.Errorf("basicauth: invalid password hash for user %q: %w", user, err))
			}
			verifiers[user] = verify
		}
		// NOTE(review): unknown users return false without running any hash
		// comparison, so response timing may reveal whether a username
		// exists (most noticeable for bcrypt entries).
		cfg.Authorizer = func(user, pass string, _ fiber.Ctx) bool {
			verify, ok := verifiers[user]
			return ok && verify(pass)
		}
	}

	if cfg.Unauthorized == nil {
		cfg.Unauthorized = func(c fiber.Ctx) error {
			header := "Basic realm=" + strconv.Quote(cfg.Realm)
			if cfg.Charset != "" {
				header += ", charset=" + strconv.Quote(cfg.Charset)
			}
			c.Set(fiber.HeaderWWWAuthenticate, header)
			// The challenge must not be cached, and it depends on the
			// Authorization header of the request.
			c.Set(fiber.HeaderCacheControl, "no-store")
			c.Set(fiber.HeaderVary, fiber.HeaderAuthorization)
			return c.SendStatus(fiber.StatusUnauthorized)
		}
	}

	return cfg
}
// parseHashedPassword converts a stored password hash into a verifier that
// reports whether a plaintext password matches it.
//
// Accepted formats:
//   - bcrypt: any value starting with "$2";
//   - "{SHA512}" followed by the standard-base64 SHA-512 digest;
//   - "{SHA256}" followed by the standard-base64 SHA-256 digest;
//   - anything else is treated as a bare SHA-256 digest, decoded as hex and,
//     if that does not yield sha256.Size bytes, as standard base64.
//
// An error is returned when a digest cannot be decoded or decodes to the
// wrong number of bytes for its algorithm. bcrypt values are not validated
// here; a malformed bcrypt hash makes the verifier always return false.
// SHA digests are compared in constant time.
func parseHashedPassword(h string) (func(string) bool, error) {
	switch {
	case strings.HasPrefix(h, "$2"):
		hash := []byte(h)
		return func(p string) bool {
			return bcrypt.CompareHashAndPassword(hash, []byte(p)) == nil
		}, nil
	case strings.HasPrefix(h, "{SHA512}"):
		b, err := base64.StdEncoding.DecodeString(h[len("{SHA512}"):])
		if err != nil {
			return nil, fmt.Errorf("decode SHA512 password: %w", err)
		}
		// Reject mis-sized digests up front. On a length mismatch
		// ConstantTimeCompare always returns 0, so the user could never
		// authenticate and nothing would report why.
		if len(b) != sha512.Size {
			return nil, errors.New("decode SHA512 password: invalid length")
		}
		return func(p string) bool {
			sum := sha512.Sum512([]byte(p))
			return subtle.ConstantTimeCompare(sum[:], b) == 1
		}, nil
	case strings.HasPrefix(h, "{SHA256}"):
		b, err := base64.StdEncoding.DecodeString(h[len("{SHA256}"):])
		if err != nil {
			return nil, fmt.Errorf("decode SHA256 password: %w", err)
		}
		// Same length guard as above, for the same reason.
		if len(b) != sha256.Size {
			return nil, errors.New("decode SHA256 password: invalid length")
		}
		return func(p string) bool {
			sum := sha256.Sum256([]byte(p))
			return subtle.ConstantTimeCompare(sum[:], b) == 1
		}, nil
	default:
		// Prefer hex; fall back to base64 when the value is not hex or does
		// not decode to exactly one SHA-256 digest.
		b, err := hex.DecodeString(h)
		if err != nil || len(b) != sha256.Size {
			if b, err = base64.StdEncoding.DecodeString(h); err != nil {
				return nil, fmt.Errorf("decode SHA256 password: %w", err)
			}
			if len(b) != sha256.Size {
				return nil, errors.New("decode SHA256 password: invalid length")
			}
		}
		return func(p string) bool {
			sum := sha256.Sum256([]byte(p))
			return subtle.ConstantTimeCompare(sum[:], b) == 1
		}, nil
	}
}