diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index 347d0e35..c0116544 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -110,32 +110,37 @@ func New(config Config) fiber.Handler { // Get authorization header auth := c.Get(fiber.HeaderAuthorization) - // Check if header is valid - if len(auth) > 6 && strings.ToLower(auth[:5]) == "basic" { - - // Try to decode - if raw, err := base64.StdEncoding.DecodeString(auth[6:]); err == nil { - - // Convert to string - cred := utils.UnsafeString(raw) - - // Find semicolumn - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Split into user & pass - user := cred[:i] - pass := cred[i+1:] - - // If exist & match in Users, we let him pass - if cfg.Authorizer(user, pass) { - c.Locals(cfg.ContextUsername, user) - c.Locals(cfg.ContextPassword, pass) - return c.Next() - } - } - } - } + // Check if the header contains content besides "basic". + if len(auth) <= 6 || strings.ToLower(auth[:5]) != "basic" { + return cfg.Unauthorized(c) } + + // Decode the header contents + raw, err := base64.StdEncoding.DecodeString(auth[6:]) + if err != nil { + return cfg.Unauthorized(c) + } + + // Get the credentials + creds := utils.UnsafeString(raw) + + // Check if the credentials are in the correct form + // which is "username:password". + index := strings.Index(creds, ":") + if index == -1 { + return cfg.Unauthorized(c) + } + + // Get the username and password + username := creds[:index] + password := creds[index+1:] + + if cfg.Authorizer(username, password) { + c.Locals(cfg.ContextUsername, username) + c.Locals(cfg.ContextPassword, password) + return c.Next() + } + // Authentication failed return cfg.Unauthorized(c) } diff --git a/middleware/basicauth/basicauth_test.go b/middleware/basicauth/basicauth_test.go index 8f0e5c9c..e652a156 100644 --- a/middleware/basicauth/basicauth_test.go +++ b/middleware/basicauth/basicauth_test.go @@ -15,6 +15,8 @@ import ( // go test -run Test_BasicAuth_Next func Test_BasicAuth_Next(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{ Next: func(_ *fiber.Ctx) bool { @@ -28,16 +30,15 @@ func Test_BasicAuth_Next(t *testing.T) { } func Test_Middleware_BasicAuth(t *testing.T) { + t.Parallel() app := fiber.New() - cfg := Config{ + app.Use(New(Config{ Users: map[string]string{ "john": "doe", "admin": "123456", }, - } - - app.Use(New(cfg)) + })) app.Get("/testauth", func(c *fiber.Ctx) error { username := c.Locals("username").(string) @@ -85,7 +86,6 @@ func Test_Middleware_BasicAuth(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, tt.statusCode, resp.StatusCode) - // Only check body if statusCode is 200 if tt.statusCode == 200 { utils.AssertEqual(t, fmt.Sprintf("%s%s", tt.username, tt.password), string(body)) }