Refactor basicauth middleware

pull/922/head
Konstantinos Lypitkas 2020-10-12 19:22:05 +03:00
parent 827bd38f5a
commit 8c49028a39
2 changed files with 35 additions and 30 deletions

View File

@ -110,32 +110,37 @@ func New(config Config) fiber.Handler {
// Get authorization header
auth := c.Get(fiber.HeaderAuthorization)
// Check if header is valid
if len(auth) > 6 && strings.ToLower(auth[:5]) == "basic" {
// Try to decode
if raw, err := base64.StdEncoding.DecodeString(auth[6:]); err == nil {
// Convert to string
cred := utils.UnsafeString(raw)
// Find semicolumn
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Split into user & pass
user := cred[:i]
pass := cred[i+1:]
// If exist & match in Users, we let him pass
if cfg.Authorizer(user, pass) {
c.Locals(cfg.ContextUsername, user)
c.Locals(cfg.ContextPassword, pass)
return c.Next()
}
}
}
}
// Check if the header contains content besides "basic".
if len(auth) <= 6 || strings.ToLower(auth[:5]) != "basic" {
return cfg.Unauthorized(c)
}
// Decode the header contents
raw, err := base64.StdEncoding.DecodeString(auth[6:])
if err != nil {
return cfg.Unauthorized(c)
}
// Get the credentials
creds := utils.UnsafeString(raw)
// Check if the credentials are in the correct form
// which is "username:password".
index := strings.Index(creds, ":")
if index == -1 {
return cfg.Unauthorized(c)
}
// Get the username and password
username := creds[:index]
password := creds[index+1:]
if cfg.Authorizer(username, password) {
c.Locals(cfg.ContextUsername, username)
c.Locals(cfg.ContextPassword, password)
return c.Next()
}
// Authentication failed
return cfg.Unauthorized(c)
}

View File

@ -15,6 +15,8 @@ import (
// go test -run Test_BasicAuth_Next
func Test_BasicAuth_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
@ -28,16 +30,15 @@ func Test_BasicAuth_Next(t *testing.T) {
}
func Test_Middleware_BasicAuth(t *testing.T) {
t.Parallel()
app := fiber.New()
cfg := Config{
app.Use(New(Config{
Users: map[string]string{
"john": "doe",
"admin": "123456",
},
}
app.Use(New(cfg))
}))
app.Get("/testauth", func(c *fiber.Ctx) error {
username := c.Locals("username").(string)
@ -85,7 +86,6 @@ func Test_Middleware_BasicAuth(t *testing.T) {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
// Only check body if statusCode is 200
if tt.statusCode == 200 {
utils.AssertEqual(t, fmt.Sprintf("%s%s", tt.username, tt.password), string(body))
}