mirror of https://github.com/gofiber/fiber.git
Refactor basicauth middleware
parent
827bd38f5a
commit
8c49028a39
|
@ -110,32 +110,37 @@ func New(config Config) fiber.Handler {
|
|||
// Get authorization header
|
||||
auth := c.Get(fiber.HeaderAuthorization)
|
||||
|
||||
// Check if header is valid
|
||||
if len(auth) > 6 && strings.ToLower(auth[:5]) == "basic" {
|
||||
|
||||
// Try to decode
|
||||
if raw, err := base64.StdEncoding.DecodeString(auth[6:]); err == nil {
|
||||
|
||||
// Convert to string
|
||||
cred := utils.UnsafeString(raw)
|
||||
|
||||
// Find semicolumn
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Split into user & pass
|
||||
user := cred[:i]
|
||||
pass := cred[i+1:]
|
||||
|
||||
// If exist & match in Users, we let him pass
|
||||
if cfg.Authorizer(user, pass) {
|
||||
c.Locals(cfg.ContextUsername, user)
|
||||
c.Locals(cfg.ContextPassword, pass)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check if the header contains content besides "basic".
|
||||
if len(auth) <= 6 || strings.ToLower(auth[:5]) != "basic" {
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
|
||||
// Decode the header contents
|
||||
raw, err := base64.StdEncoding.DecodeString(auth[6:])
|
||||
if err != nil {
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
|
||||
// Get the credentials
|
||||
creds := utils.UnsafeString(raw)
|
||||
|
||||
// Check if the credentials are in the correct form
|
||||
// which is "username:password".
|
||||
index := strings.Index(creds, ":")
|
||||
if index == -1 {
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
|
||||
// Get the username and password
|
||||
username := creds[:index]
|
||||
password := creds[index+1:]
|
||||
|
||||
if cfg.Authorizer(username, password) {
|
||||
c.Locals(cfg.ContextUsername, username)
|
||||
c.Locals(cfg.ContextPassword, password)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Authentication failed
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@ import (
|
|||
|
||||
// go test -run Test_BasicAuth_Next
|
||||
func Test_BasicAuth_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
|
@ -28,16 +30,15 @@ func Test_BasicAuth_Next(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_Middleware_BasicAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
cfg := Config{
|
||||
app.Use(New(Config{
|
||||
Users: map[string]string{
|
||||
"john": "doe",
|
||||
"admin": "123456",
|
||||
},
|
||||
}
|
||||
|
||||
app.Use(New(cfg))
|
||||
}))
|
||||
|
||||
app.Get("/testauth", func(c *fiber.Ctx) error {
|
||||
username := c.Locals("username").(string)
|
||||
|
@ -85,7 +86,6 @@ func Test_Middleware_BasicAuth(t *testing.T) {
|
|||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
|
||||
|
||||
// Only check body if statusCode is 200
|
||||
if tt.statusCode == 200 {
|
||||
utils.AssertEqual(t, fmt.Sprintf("%s%s", tt.username, tt.password), string(body))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue