login_source: migrate to GORM and add tests (#6090)

* Use GORM in all write paths

* Migrate to GORM

* Fix lint errors

* Use GORM  to init table

* dbutil: make writer detect error

* Add more tests

* Rename to clearTables

* db: finish adding tests

* osutil: add tests

* Fix load source files path
pull/6092/head
ᴜɴᴋɴᴡᴏɴ 2020-04-11 20:18:05 +08:00 committed by GitHub
parent 76bb647d24
commit 41f56ad05d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1119 additions and 652 deletions

File diff suppressed because one or more lines are too long

View File

@ -19,9 +19,9 @@ type SecurityProtocol int
// Note: new type must be added at the end of list to maintain compatibility.
const (
SECURITY_PROTOCOL_UNENCRYPTED SecurityProtocol = iota
SECURITY_PROTOCOL_LDAPS
SECURITY_PROTOCOL_START_TLS
SecurityProtocolUnencrypted SecurityProtocol = iota
SecurityProtocolLDAPS
SecurityProtocolStartTLS
)
// Basic LDAP authentication service
@ -144,7 +144,7 @@ func dial(ls *Source) (*ldap.Conn, error) {
ServerName: ls.Host,
InsecureSkipVerify: ls.SkipVerify,
}
if ls.SecurityProtocol == SECURITY_PROTOCOL_LDAPS {
if ls.SecurityProtocol == SecurityProtocolLDAPS {
return ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ls.Host, ls.Port), tlsCfg)
}
@ -153,7 +153,7 @@ func dial(ls *Source) (*ldap.Conn, error) {
return nil, fmt.Errorf("Dial: %v", err)
}
if ls.SecurityProtocol == SECURITY_PROTOCOL_START_TLS {
if ls.SecurityProtocol == SecurityProtocolStartTLS {
if err = conn.StartTLS(tlsCfg); err != nil {
conn.Close()
return nil, fmt.Errorf("StartTLS: %v", err)

View File

@ -21,8 +21,9 @@ func Test_accessTokens(t *testing.T) {
t.Parallel()
tables := []interface{}{new(AccessToken)}
db := &accessTokens{
DB: initTestDB(t, "accessTokens", new(AccessToken)),
DB: initTestDB(t, "accessTokens", tables...),
}
for _, tc := range []struct {
@ -37,7 +38,7 @@ func Test_accessTokens(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := deleteTables(db.DB, new(AccessToken))
err := clearTables(db.DB, tables...)
if err != nil {
t.Fatal(err)
}
@ -78,14 +79,14 @@ func test_accessTokens_DeleteByID(t *testing.T, db *accessTokens) {
t.Fatal(err)
}
// We should be able to get it back
_, err = db.GetBySHA(token.Sha1)
// Delete a token with mismatched user ID is noop
err = db.DeleteByID(2, token.ID)
if err != nil {
t.Fatal(err)
}
// Delete a token with mismatched user ID is noop
err = db.DeleteByID(2, token.ID)
// We should be able to get it back
_, err = db.GetBySHA(token.Sha1)
if err != nil {
t.Fatal(err)
}

View File

@ -124,7 +124,7 @@ func getLogWriter() (io.Writer, error) {
var tables = []interface{}{
new(AccessToken),
new(LFSObject),
new(LFSObject), new(LoginSource),
}
func Init() error {
@ -167,9 +167,14 @@ func Init() error {
return time.Now().UTC().Truncate(time.Microsecond)
}
sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"))
if err != nil {
return errors.Wrap(err, "load login source files")
}
// Initialize stores, sorted in alphabetical order.
AccessTokens = &accessTokens{DB: db}
LoginSources = &loginSources{DB: db}
LoginSources = &loginSources{DB: db, files: sourceFiles}
LFS = &lfs{DB: db}
Perms = &perms{DB: db}
Repos = &repos{DB: db}

View File

@ -327,39 +327,6 @@ func (err ErrRepoFileAlreadyExist) Error() string {
return fmt.Sprintf("repository file already exists [file_name: %s]", err.FileName)
}
// .____ .__ _________
// | | ____ ____ |__| ____ / _____/ ____ __ _________ ____ ____
// | | / _ \ / ___\| |/ \ \_____ \ / _ \| | \_ __ \_/ ___\/ __ \
// | |__( <_> ) /_/ > | | \ / ( <_> ) | /| | \/\ \__\ ___/
// |_______ \____/\___ /|__|___| / /_______ /\____/|____/ |__| \___ >___ >
// \/ /_____/ \/ \/ \/ \/
type ErrLoginSourceAlreadyExist struct {
Name string
}
func IsErrLoginSourceAlreadyExist(err error) bool {
_, ok := err.(ErrLoginSourceAlreadyExist)
return ok
}
func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists [name: %s]", err.Name)
}
type ErrLoginSourceInUse struct {
ID int64
}
func IsErrLoginSourceInUse(err error) bool {
_, ok := err.(ErrLoginSourceInUse)
return ok
}
func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
}
// ___________
// \__ ___/___ _____ _____
// | |_/ __ \\__ \ / \

View File

@ -6,19 +6,6 @@ package errors
import "fmt"
type LoginSourceNotExist struct {
ID int64
}
func IsLoginSourceNotExist(err error) bool {
_, ok := err.(LoginSourceNotExist)
return ok
}
func (err LoginSourceNotExist) Error() string {
return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
}
type LoginSourceNotActivated struct {
SourceID int64
}
@ -44,4 +31,3 @@ func IsInvalidLoginSourceType(err error) bool {
func (err InvalidLoginSourceType) Error() string {
return fmt.Sprintf("invalid login source type [type: %v]", err.Type)
}

View File

@ -22,8 +22,9 @@ func Test_lfs(t *testing.T) {
t.Parallel()
tables := []interface{}{new(LFSObject)}
db := &lfs{
DB: initTestDB(t, "lfs", new(LFSObject)),
DB: initTestDB(t, "lfs", tables...),
}
for _, tc := range []struct {
@ -36,7 +37,7 @@ func Test_lfs(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := deleteTables(db.DB, new(LFSObject))
err := clearTables(db.DB, tables...)
if err != nil {
t.Fatal(err)
}

View File

@ -10,30 +10,21 @@ import (
"fmt"
"net/smtp"
"net/textproto"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/go-macaron/binding"
"github.com/json-iterator/go"
"github.com/unknwon/com"
"gopkg.in/ini.v1"
log "unknwon.dev/clog/v2"
"xorm.io/core"
"xorm.io/xorm"
"gogs.io/gogs/internal/auth/github"
"gogs.io/gogs/internal/auth/ldap"
"gogs.io/gogs/internal/auth/pam"
"gogs.io/gogs/internal/conf"
"gogs.io/gogs/internal/db/errors"
)
type LoginType int
// Note: new type must append to the end of list to maintain compatibility.
// TODO: Move to authutil.
const (
LoginNotype LoginType = iota
LoginPlain // 1
@ -52,497 +43,24 @@ var LoginNames = map[LoginType]string{
LoginGitHub: "GitHub",
}
var SecurityProtocolNames = map[ldap.SecurityProtocol]string{
ldap.SECURITY_PROTOCOL_UNENCRYPTED: "Unencrypted",
ldap.SECURITY_PROTOCOL_LDAPS: "LDAPS",
ldap.SECURITY_PROTOCOL_START_TLS: "StartTLS",
}
// Ensure structs implemented interface.
var (
_ core.Conversion = &LDAPConfig{}
_ core.Conversion = &SMTPConfig{}
_ core.Conversion = &PAMConfig{}
_ core.Conversion = &GitHubConfig{}
)
// ***********************
// ----- LDAP config -----
// ***********************
type LDAPConfig struct {
*ldap.Source `ini:"config"`
ldap.Source `ini:"config"`
}
func (cfg *LDAPConfig) FromDB(bs []byte) error {
return jsoniter.Unmarshal(bs, &cfg)
}
func (cfg *LDAPConfig) ToDB() ([]byte, error) {
return jsoniter.Marshal(cfg)
var SecurityProtocolNames = map[ldap.SecurityProtocol]string{
ldap.SecurityProtocolUnencrypted: "Unencrypted",
ldap.SecurityProtocolLDAPS: "LDAPS",
ldap.SecurityProtocolStartTLS: "StartTLS",
}
func (cfg *LDAPConfig) SecurityProtocolName() string {
return SecurityProtocolNames[cfg.SecurityProtocol]
}
type SMTPConfig struct {
Auth string
Host string
Port int
AllowedDomains string `xorm:"TEXT"`
TLS bool `ini:"tls"`
SkipVerify bool
}
func (cfg *SMTPConfig) FromDB(bs []byte) error {
return jsoniter.Unmarshal(bs, cfg)
}
func (cfg *SMTPConfig) ToDB() ([]byte, error) {
return jsoniter.Marshal(cfg)
}
type PAMConfig struct {
ServiceName string // PAM service (e.g. system-auth)
}
func (cfg *PAMConfig) FromDB(bs []byte) error {
return jsoniter.Unmarshal(bs, &cfg)
}
func (cfg *PAMConfig) ToDB() ([]byte, error) {
return jsoniter.Marshal(cfg)
}
type GitHubConfig struct {
APIEndpoint string // GitHub service (e.g. https://api.github.com/)
}
func (cfg *GitHubConfig) FromDB(bs []byte) error {
return jsoniter.Unmarshal(bs, &cfg)
}
func (cfg *GitHubConfig) ToDB() ([]byte, error) {
return jsoniter.Marshal(cfg)
}
// AuthSourceFile contains information of an authentication source file.
type AuthSourceFile struct {
abspath string
file *ini.File
}
// SetGeneral sets new value to the given key in the general (default) section.
func (f *AuthSourceFile) SetGeneral(name, value string) {
f.file.Section("").Key(name).SetValue(value)
}
// SetConfig sets new values to the "config" section.
func (f *AuthSourceFile) SetConfig(cfg core.Conversion) error {
return f.file.Section("config").ReflectFrom(cfg)
}
// Save writes updates into file system.
func (f *AuthSourceFile) Save() error {
return f.file.SaveTo(f.abspath)
}
// LoginSource represents an external way for authorizing users.
type LoginSource struct {
ID int64
Type LoginType
Name string `xorm:"UNIQUE"`
IsActived bool `xorm:"NOT NULL DEFAULT false"`
IsDefault bool `xorm:"DEFAULT false"`
Cfg core.Conversion `xorm:"TEXT" gorm:"COLUMN:remove-me-when-migrated-to-gorm"`
RawCfg string `xorm:"-" gorm:"COLUMN:cfg"` // TODO: Remove me when migrated to GORM.
Created time.Time `xorm:"-" json:"-"`
CreatedUnix int64
Updated time.Time `xorm:"-" json:"-"`
UpdatedUnix int64
LocalFile *AuthSourceFile `xorm:"-" json:"-"`
}
func (s *LoginSource) BeforeInsert() {
s.CreatedUnix = time.Now().Unix()
s.UpdatedUnix = s.CreatedUnix
}
func (s *LoginSource) BeforeUpdate() {
s.UpdatedUnix = time.Now().Unix()
}
// Cell2Int64 converts a xorm.Cell type to int64,
// and handles possible irregular cases.
func Cell2Int64(val xorm.Cell) int64 {
switch (*val).(type) {
case []uint8:
log.Trace("Cell2Int64 ([]uint8): %v", *val)
return com.StrTo(string((*val).([]uint8))).MustInt64()
}
return (*val).(int64)
}
func (s *LoginSource) BeforeSet(colName string, val xorm.Cell) {
switch colName {
case "type":
switch LoginType(Cell2Int64(val)) {
case LoginLDAP, LoginDLDAP:
s.Cfg = new(LDAPConfig)
case LoginSMTP:
s.Cfg = new(SMTPConfig)
case LoginPAM:
s.Cfg = new(PAMConfig)
case LoginGitHub:
s.Cfg = new(GitHubConfig)
default:
panic("unrecognized login source type: " + com.ToStr(*val))
}
}
}
func (s *LoginSource) AfterSet(colName string, _ xorm.Cell) {
switch colName {
case "created_unix":
s.Created = time.Unix(s.CreatedUnix, 0).Local()
case "updated_unix":
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
}
}
// NOTE: This is a GORM query hook.
func (s *LoginSource) AfterFind() error {
switch s.Type {
case LoginLDAP, LoginDLDAP:
s.Cfg = new(LDAPConfig)
case LoginSMTP:
s.Cfg = new(SMTPConfig)
case LoginPAM:
s.Cfg = new(PAMConfig)
case LoginGitHub:
s.Cfg = new(GitHubConfig)
default:
return fmt.Errorf("unrecognized login source type: %v", s.Type)
}
return jsoniter.UnmarshalFromString(s.RawCfg, s.Cfg)
}
func (s *LoginSource) TypeName() string {
return LoginNames[s.Type]
}
func (s *LoginSource) IsLDAP() bool {
return s.Type == LoginLDAP
}
func (s *LoginSource) IsDLDAP() bool {
return s.Type == LoginDLDAP
}
func (s *LoginSource) IsSMTP() bool {
return s.Type == LoginSMTP
}
func (s *LoginSource) IsPAM() bool {
return s.Type == LoginPAM
}
func (s *LoginSource) IsGitHub() bool {
return s.Type == LoginGitHub
}
func (s *LoginSource) HasTLS() bool {
return ((s.IsLDAP() || s.IsDLDAP()) &&
s.LDAP().SecurityProtocol > ldap.SECURITY_PROTOCOL_UNENCRYPTED) ||
s.IsSMTP()
}
func (s *LoginSource) UseTLS() bool {
switch s.Type {
case LoginLDAP, LoginDLDAP:
return s.LDAP().SecurityProtocol != ldap.SECURITY_PROTOCOL_UNENCRYPTED
case LoginSMTP:
return s.SMTP().TLS
}
return false
}
func (s *LoginSource) SkipVerify() bool {
switch s.Type {
case LoginLDAP, LoginDLDAP:
return s.LDAP().SkipVerify
case LoginSMTP:
return s.SMTP().SkipVerify
}
return false
}
func (s *LoginSource) LDAP() *LDAPConfig {
return s.Cfg.(*LDAPConfig)
}
func (s *LoginSource) SMTP() *SMTPConfig {
return s.Cfg.(*SMTPConfig)
}
func (s *LoginSource) PAM() *PAMConfig {
return s.Cfg.(*PAMConfig)
}
func (s *LoginSource) GitHub() *GitHubConfig {
return s.Cfg.(*GitHubConfig)
}
func CreateLoginSource(source *LoginSource) error {
has, err := x.Get(&LoginSource{Name: source.Name})
if err != nil {
return err
} else if has {
return ErrLoginSourceAlreadyExist{source.Name}
}
_, err = x.Insert(source)
if err != nil {
return err
} else if source.IsDefault {
return ResetNonDefaultLoginSources(source)
}
return nil
}
// ListLoginSources returns all login sources defined.
func ListLoginSources() ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 2)
if err := x.Find(&sources); err != nil {
return nil, err
}
return append(sources, localLoginSources.List()...), nil
}
// ActivatedLoginSources returns login sources that are currently activated.
func ActivatedLoginSources() ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 2)
if err := x.Where("is_actived = ?", true).Find(&sources); err != nil {
return nil, fmt.Errorf("find activated login sources: %v", err)
}
return append(sources, localLoginSources.ActivatedList()...), nil
}
// ResetNonDefaultLoginSources clean other default source flag
func ResetNonDefaultLoginSources(source *LoginSource) error {
// update changes to DB
if _, err := x.NotIn("id", []int64{source.ID}).Cols("is_default").Update(&LoginSource{IsDefault: false}); err != nil {
return err
}
// write changes to local authentications
for i := range localLoginSources.sources {
if localLoginSources.sources[i].LocalFile != nil && localLoginSources.sources[i].ID != source.ID {
localLoginSources.sources[i].LocalFile.SetGeneral("is_default", "false")
if err := localLoginSources.sources[i].LocalFile.SetConfig(source.Cfg); err != nil {
return fmt.Errorf("LocalFile.SetConfig: %v", err)
} else if err = localLoginSources.sources[i].LocalFile.Save(); err != nil {
return fmt.Errorf("LocalFile.Save: %v", err)
}
}
}
// flush memory so that web page can show the same behaviors
localLoginSources.UpdateLoginSource(source)
return nil
}
// UpdateLoginSource updates information of login source to database or local file.
func UpdateLoginSource(source *LoginSource) error {
if source.LocalFile == nil {
if _, err := x.Id(source.ID).AllCols().Update(source); err != nil {
return err
} else {
return ResetNonDefaultLoginSources(source)
}
}
source.LocalFile.SetGeneral("name", source.Name)
source.LocalFile.SetGeneral("is_activated", com.ToStr(source.IsActived))
source.LocalFile.SetGeneral("is_default", com.ToStr(source.IsDefault))
if err := source.LocalFile.SetConfig(source.Cfg); err != nil {
return fmt.Errorf("LocalFile.SetConfig: %v", err)
} else if err = source.LocalFile.Save(); err != nil {
return fmt.Errorf("LocalFile.Save: %v", err)
}
return ResetNonDefaultLoginSources(source)
}
func DeleteSource(source *LoginSource) error {
count, err := x.Count(&User{LoginSource: source.ID})
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{source.ID}
}
_, err = x.Id(source.ID).Delete(new(LoginSource))
return err
}
// CountLoginSources returns total number of login sources.
func CountLoginSources() int64 {
count, _ := x.Count(new(LoginSource))
return count + int64(localLoginSources.Len())
}
// LocalLoginSources contains authentication sources configured and loaded from local files.
// Calling its methods is thread-safe; otherwise, please maintain the mutex accordingly.
type LocalLoginSources struct {
sync.RWMutex
sources []*LoginSource
}
func (s *LocalLoginSources) Len() int {
return len(s.sources)
}
// List returns full clone of login sources.
func (s *LocalLoginSources) List() []*LoginSource {
s.RLock()
defer s.RUnlock()
list := make([]*LoginSource, s.Len())
for i := range s.sources {
list[i] = &LoginSource{}
*list[i] = *s.sources[i]
}
return list
}
// ActivatedList returns clone of activated login sources.
func (s *LocalLoginSources) ActivatedList() []*LoginSource {
s.RLock()
defer s.RUnlock()
list := make([]*LoginSource, 0, 2)
for i := range s.sources {
if !s.sources[i].IsActived {
continue
}
source := &LoginSource{}
*source = *s.sources[i]
list = append(list, source)
}
return list
}
// GetLoginSourceByID returns a clone of login source by given ID.
func (s *LocalLoginSources) GetLoginSourceByID(id int64) (*LoginSource, error) {
s.RLock()
defer s.RUnlock()
for i := range s.sources {
if s.sources[i].ID == id {
source := &LoginSource{}
*source = *s.sources[i]
return source, nil
}
}
return nil, errors.LoginSourceNotExist{ID: id}
}
// UpdateLoginSource updates in-memory copy of the authentication source.
func (s *LocalLoginSources) UpdateLoginSource(source *LoginSource) {
s.Lock()
defer s.Unlock()
source.Updated = time.Now()
for i := range s.sources {
if s.sources[i].ID == source.ID {
*s.sources[i] = *source
} else if source.IsDefault {
s.sources[i].IsDefault = false
}
}
}
var localLoginSources = &LocalLoginSources{}
// LoadAuthSources loads authentication sources from local files
// and converts them into login sources.
func LoadAuthSources() {
authdPath := filepath.Join(conf.CustomDir(), "conf", "auth.d")
if !com.IsDir(authdPath) {
return
}
paths, err := com.GetFileListBySuffix(authdPath, ".conf")
if err != nil {
log.Fatal("Failed to list authentication sources: %v", err)
}
localLoginSources.sources = make([]*LoginSource, 0, len(paths))
for _, fpath := range paths {
authSource, err := ini.Load(fpath)
if err != nil {
log.Fatal("Failed to load authentication source: %v", err)
}
authSource.NameMapper = ini.TitleUnderscore
// Set general attributes
s := authSource.Section("")
loginSource := &LoginSource{
ID: s.Key("id").MustInt64(),
Name: s.Key("name").String(),
IsActived: s.Key("is_activated").MustBool(),
IsDefault: s.Key("is_default").MustBool(),
LocalFile: &AuthSourceFile{
abspath: fpath,
file: authSource,
},
}
fi, err := os.Stat(fpath)
if err != nil {
log.Fatal("Failed to load authentication source: %v", err)
}
loginSource.Updated = fi.ModTime()
// Parse authentication source file
authType := s.Key("type").String()
switch authType {
case "ldap_bind_dn":
loginSource.Type = LoginLDAP
loginSource.Cfg = &LDAPConfig{}
case "ldap_simple_auth":
loginSource.Type = LoginDLDAP
loginSource.Cfg = &LDAPConfig{}
case "smtp":
loginSource.Type = LoginSMTP
loginSource.Cfg = &SMTPConfig{}
case "pam":
loginSource.Type = LoginPAM
loginSource.Cfg = &PAMConfig{}
case "github":
loginSource.Type = LoginGitHub
loginSource.Cfg = &GitHubConfig{}
default:
log.Fatal("Failed to load authentication source: unknown type '%s'", authType)
}
if err = authSource.Section("config").MapTo(loginSource.Cfg); err != nil {
log.Fatal("Failed to parse authentication source 'config': %v", err)
}
localLoginSources.sources = append(localLoginSources.sources, loginSource)
}
}
// .____ ________ _____ __________
// | | \______ \ / _ \\______ \
// | | | | \ / /_\ \| ___/
// | |___ | ` \/ | \ |
// |_______ \/_______ /\____|__ /____|
// \/ \/ \/
func composeFullName(firstname, surname, username string) string {
switch {
case len(firstname) == 0 && len(surname) == 0:
@ -559,7 +77,7 @@ func composeFullName(firstname, surname, username string) string {
// LoginViaLDAP queries if login/password is valid against the LDAP directory pool,
// and create a local user if success when enabled.
func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool) (*User, error) {
username, fn, sn, mail, isAdmin, succeed := source.Cfg.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP)
username, fn, sn, mail, isAdmin, succeed := source.Config.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP)
if !succeed {
// User not in LDAP, do nothing
return nil, ErrUserNotExist{args: map[string]interface{}{"login": login}}
@ -606,12 +124,18 @@ func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool
return user, CreateUser(user)
}
// _________ __________________________
// / _____/ / \__ ___/\______ \
// \_____ \ / \ / \| | | ___/
// / \/ Y \ | | |
// /_______ /\____|__ /____| |____|
// \/ \/
// ***********************
// ----- SMTP config -----
// ***********************
type SMTPConfig struct {
Auth string
Host string
Port int
AllowedDomains string
TLS bool `ini:"tls"`
SkipVerify bool
}
type smtpLoginAuth struct {
username, password string
@ -634,11 +158,11 @@ func (auth *smtpLoginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
}
const (
SMTP_PLAIN = "PLAIN"
SMTP_LOGIN = "LOGIN"
SMTPPlain = "PLAIN"
SMTPLogin = "LOGIN"
)
var SMTPAuths = []string{SMTP_PLAIN, SMTP_LOGIN}
var SMTPAuths = []string{SMTPPlain, SMTPLogin}
func SMTPAuth(a smtp.Auth, cfg *SMTPConfig) error {
c, err := smtp.Dial(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port))
@ -687,9 +211,9 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR
}
var auth smtp.Auth
if cfg.Auth == SMTP_PLAIN {
if cfg.Auth == SMTPPlain {
auth = smtp.PlainAuth("", login, password, cfg.Host)
} else if cfg.Auth == SMTP_LOGIN {
} else if cfg.Auth == SMTPLogin {
auth = &smtpLoginAuth{login, password}
} else {
return nil, errors.New("Unsupported SMTP authentication type")
@ -729,12 +253,14 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR
return user, CreateUser(user)
}
// __________ _____ _____
// \______ \/ _ \ / \
// | ___/ /_\ \ / \ / \
// | | / | \/ Y \
// |____| \____|__ /\____|__ /
// \/ \/
// **********************
// ----- PAM config -----
// **********************
type PAMConfig struct {
// The name of the PAM service, e.g. system-auth.
ServiceName string
}
// LoginViaPAM queries if login/password is valid against the PAM,
// and create a local user if success when enabled.
@ -763,12 +289,14 @@ func LoginViaPAM(login, password string, sourceID int64, cfg *PAMConfig, autoReg
return user, CreateUser(user)
}
// ________.__ __ ___ ___ ___.
// / _____/|__|/ |_ / | \ __ _\_ |__
// / \ ___| \ __\/ ~ \ | \ __ \
// \ \_\ \ || | \ Y / | / \_\ \
// \______ /__||__| \___|_ /|____/|___ /
// \/ \/ \/
// *************************
// ----- GitHub config -----
// *************************
type GitHubConfig struct {
// the GitHub service endpoint, e.g. https://api.github.com/.
APIEndpoint string
}
func LoginViaGitHub(login, password string, sourceID int64, cfg *GitHubConfig, autoRegister bool) (*User, error) {
fullname, email, url, location, err := github.Authenticate(cfg.APIEndpoint, login, password)
@ -807,11 +335,11 @@ func authenticateViaLoginSource(source *LoginSource, login, password string, aut
case LoginLDAP, LoginDLDAP:
return LoginViaLDAP(login, password, source, autoRegister)
case LoginSMTP:
return LoginViaSMTP(login, password, source.ID, source.Cfg.(*SMTPConfig), autoRegister)
return LoginViaSMTP(login, password, source.ID, source.Config.(*SMTPConfig), autoRegister)
case LoginPAM:
return LoginViaPAM(login, password, source.ID, source.Cfg.(*PAMConfig), autoRegister)
return LoginViaPAM(login, password, source.ID, source.Config.(*PAMConfig), autoRegister)
case LoginGitHub:
return LoginViaGitHub(login, password, source.ID, source.Cfg.(*GitHubConfig), autoRegister)
return LoginViaGitHub(login, password, source.ID, source.Config.(*GitHubConfig), autoRegister)
}
return nil, errors.InvalidLoginSourceType{Type: source.Type}

View File

@ -0,0 +1,212 @@
// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"gopkg.in/ini.v1"
"gogs.io/gogs/internal/errutil"
"gogs.io/gogs/internal/osutil"
)
// loginSourceFilesStore is the in-memory interface for login source files stored on file system.
//
// NOTE: All methods are sorted in alphabetical order.
type loginSourceFilesStore interface {
// GetByID returns a clone of login source by given ID.
GetByID(id int64) (*LoginSource, error)
// Len returns number of login sources.
Len() int
// List returns a list of login sources filtered by options.
List(opts ListLoginSourceOpts) []*LoginSource
// Update updates in-memory copy of the authentication source.
Update(source *LoginSource)
}
var _ loginSourceFilesStore = (*loginSourceFiles)(nil)
// loginSourceFiles contains authentication sources configured and loaded from local files.
type loginSourceFiles struct {
sync.RWMutex
sources []*LoginSource
}
var _ errutil.NotFound = (*ErrLoginSourceNotExist)(nil)
type ErrLoginSourceNotExist struct {
args errutil.Args
}
func IsErrLoginSourceNotExist(err error) bool {
_, ok := err.(ErrLoginSourceNotExist)
return ok
}
func (err ErrLoginSourceNotExist) Error() string {
return fmt.Sprintf("login source does not exist: %v", err.args)
}
func (ErrLoginSourceNotExist) NotFound() bool {
return true
}
func (s *loginSourceFiles) GetByID(id int64) (*LoginSource, error) {
s.RLock()
defer s.RUnlock()
for _, source := range s.sources {
if source.ID == id {
return source, nil
}
}
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
}
func (s *loginSourceFiles) Len() int {
s.RLock()
defer s.RUnlock()
return len(s.sources)
}
func (s *loginSourceFiles) List(opts ListLoginSourceOpts) []*LoginSource {
s.RLock()
defer s.RUnlock()
list := make([]*LoginSource, 0, s.Len())
for _, source := range s.sources {
if opts.OnlyActivated && !source.IsActived {
continue
}
list = append(list, source)
}
return list
}
func (s *loginSourceFiles) Update(source *LoginSource) {
s.Lock()
defer s.Unlock()
source.Updated = gorm.NowFunc()
for _, old := range s.sources {
if old.ID == source.ID {
*old = *source
} else if source.IsDefault {
old.IsDefault = false
}
}
}
// loadLoginSourceFiles loads login sources from file system.
func loadLoginSourceFiles(authdPath string) (loginSourceFilesStore, error) {
if !osutil.IsDir(authdPath) {
return &loginSourceFiles{}, nil
}
store := &loginSourceFiles{}
return store, filepath.Walk(authdPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if path == authdPath || !strings.HasSuffix(path, ".conf") {
return nil
} else if info.IsDir() {
return filepath.SkipDir
}
authSource, err := ini.Load(path)
if err != nil {
return errors.Wrap(err, "load file")
}
authSource.NameMapper = ini.TitleUnderscore
// Set general attributes
s := authSource.Section("")
loginSource := &LoginSource{
ID: s.Key("id").MustInt64(),
Name: s.Key("name").String(),
IsActived: s.Key("is_activated").MustBool(),
IsDefault: s.Key("is_default").MustBool(),
File: &loginSourceFile{
path: path,
file: authSource,
},
}
fi, err := os.Stat(path)
if err != nil {
return errors.Wrap(err, "stat file")
}
loginSource.Updated = fi.ModTime()
// Parse authentication source file
authType := s.Key("type").String()
switch authType {
case "ldap_bind_dn":
loginSource.Type = LoginLDAP
loginSource.Config = &LDAPConfig{}
case "ldap_simple_auth":
loginSource.Type = LoginDLDAP
loginSource.Config = &LDAPConfig{}
case "smtp":
loginSource.Type = LoginSMTP
loginSource.Config = &SMTPConfig{}
case "pam":
loginSource.Type = LoginPAM
loginSource.Config = &PAMConfig{}
case "github":
loginSource.Type = LoginGitHub
loginSource.Config = &GitHubConfig{}
default:
return fmt.Errorf("unknown type %q", authType)
}
if err = authSource.Section("config").MapTo(loginSource.Config); err != nil {
return errors.Wrap(err, `map "config" section`)
}
store.sources = append(store.sources, loginSource)
return nil
})
}
// loginSourceFileStore is the persistent interface for a login source file.
type loginSourceFileStore interface {
// SetGeneral sets new value to the given key in the general (default) section.
SetGeneral(name, value string)
// SetConfig sets new values to the "config" section.
SetConfig(cfg interface{}) error
// Save persists values to file system.
Save() error
}
var _ loginSourceFileStore = (*loginSourceFile)(nil)
type loginSourceFile struct {
path string
file *ini.File
}
func (f *loginSourceFile) SetGeneral(name, value string) {
f.file.Section("").Key(name).SetValue(value)
}
func (f *loginSourceFile) SetConfig(cfg interface{}) error {
return f.file.Section("config").ReflectFrom(cfg)
}
func (f *loginSourceFile) Save() error {
return f.file.SaveTo(f.path)
}

View File

@ -5,22 +5,242 @@
package db
import (
"fmt"
"strconv"
"time"
"github.com/jinzhu/gorm"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"gogs.io/gogs/internal/auth/ldap"
"gogs.io/gogs/internal/errutil"
)
// LoginSourcesStore is the persistent interface for login sources.
//
// NOTE: All methods are sorted in alphabetical order.
type LoginSourcesStore interface {
// Create creates a new login source and persist to database.
// It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
Create(opts CreateLoginSourceOpts) (*LoginSource, error)
// Count returns the total number of login sources.
Count() int64
// DeleteByID deletes a login source by given ID.
// It returns ErrLoginSourceInUse if at least one user is associated with the login source.
DeleteByID(id int64) error
// GetByID returns the login source with given ID.
// It returns ErrLoginSourceNotExist when not found.
GetByID(id int64) (*LoginSource, error)
// List returns a list of login sources filtered by options.
List(opts ListLoginSourceOpts) ([]*LoginSource, error)
// ResetNonDefault clears default flag for all the other login sources.
ResetNonDefault(source *LoginSource) error
// Save persists all values of given login source to database or local file.
// The Updated field is set to current time automatically.
Save(t *LoginSource) error
}
var LoginSources LoginSourcesStore
// LoginSource represents an external way for authorizing users.
type LoginSource struct {
ID int64
Type LoginType
Name string `xorm:"UNIQUE"`
IsActived bool `xorm:"NOT NULL DEFAULT false"`
IsDefault bool `xorm:"DEFAULT false"`
Config interface{} `xorm:"-" gorm:"-"`
RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg"`
Created time.Time `xorm:"-" gorm:"-" json:"-"`
CreatedUnix int64
Updated time.Time `xorm:"-" gorm:"-" json:"-"`
UpdatedUnix int64
File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
}
// NOTE: This is a GORM save hook.
func (s *LoginSource) BeforeSave() (err error) {
s.RawConfig, err = jsoniter.MarshalToString(s.Config)
return err
}
// NOTE: This is a GORM create hook.
func (s *LoginSource) BeforeCreate() {
s.CreatedUnix = gorm.NowFunc().Unix()
s.UpdatedUnix = s.CreatedUnix
}
// NOTE: This is a GORM update hook.
func (s *LoginSource) BeforeUpdate() {
s.UpdatedUnix = gorm.NowFunc().Unix()
}
// NOTE: This is a GORM query hook.
func (s *LoginSource) AfterFind() error {
s.Created = time.Unix(s.CreatedUnix, 0).Local()
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
switch s.Type {
case LoginLDAP, LoginDLDAP:
s.Config = new(LDAPConfig)
case LoginSMTP:
s.Config = new(SMTPConfig)
case LoginPAM:
s.Config = new(PAMConfig)
case LoginGitHub:
s.Config = new(GitHubConfig)
default:
return fmt.Errorf("unrecognized login source type: %v", s.Type)
}
return jsoniter.UnmarshalFromString(s.RawConfig, s.Config)
}
func (s *LoginSource) TypeName() string {
return LoginNames[s.Type]
}
func (s *LoginSource) IsLDAP() bool {
return s.Type == LoginLDAP
}
func (s *LoginSource) IsDLDAP() bool {
return s.Type == LoginDLDAP
}
func (s *LoginSource) IsSMTP() bool {
return s.Type == LoginSMTP
}
func (s *LoginSource) IsPAM() bool {
return s.Type == LoginPAM
}
func (s *LoginSource) IsGitHub() bool {
return s.Type == LoginGitHub
}
func (s *LoginSource) HasTLS() bool {
return ((s.IsLDAP() || s.IsDLDAP()) &&
s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) ||
s.IsSMTP()
}
func (s *LoginSource) UseTLS() bool {
switch s.Type {
case LoginLDAP, LoginDLDAP:
return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted
case LoginSMTP:
return s.SMTP().TLS
}
return false
}
func (s *LoginSource) SkipVerify() bool {
switch s.Type {
case LoginLDAP, LoginDLDAP:
return s.LDAP().SkipVerify
case LoginSMTP:
return s.SMTP().SkipVerify
}
return false
}
func (s *LoginSource) LDAP() *LDAPConfig {
return s.Config.(*LDAPConfig)
}
func (s *LoginSource) SMTP() *SMTPConfig {
return s.Config.(*SMTPConfig)
}
func (s *LoginSource) PAM() *PAMConfig {
return s.Config.(*PAMConfig)
}
func (s *LoginSource) GitHub() *GitHubConfig {
return s.Config.(*GitHubConfig)
}
var _ LoginSourcesStore = (*loginSources)(nil)
type loginSources struct {
*gorm.DB
files loginSourceFilesStore
}
type CreateLoginSourceOpts struct {
Type LoginType
Name string
Activated bool
Default bool
Config interface{}
}
type ErrLoginSourceAlreadyExist struct {
args errutil.Args
}
func IsErrLoginSourceAlreadyExist(err error) bool {
_, ok := err.(ErrLoginSourceAlreadyExist)
return ok
}
func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists: %v", err.args)
}
func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
} else if !gorm.IsRecordNotFoundError(err) {
return nil, err
}
source := &LoginSource{
Type: opts.Type,
Name: opts.Name,
IsActived: opts.Activated,
IsDefault: opts.Default,
Config: opts.Config,
}
return source, db.DB.Create(source).Error
}
func (db *loginSources) Count() int64 {
var count int64
db.Model(new(LoginSource)).Count(&count)
return count + int64(db.files.Len())
}
type ErrLoginSourceInUse struct {
args errutil.Args
}
func IsErrLoginSourceInUse(err error) bool {
_, ok := err.(ErrLoginSourceInUse)
return ok
}
func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users: %v", err.args)
}
func (db *loginSources) DeleteByID(id int64) error {
var count int64
err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
}
return db.Where("id = ?", id).Delete(new(LoginSource)).Error
}
func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
@ -28,9 +248,63 @@ func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
err := db.Where("id = ?", id).First(source).Error
if err != nil {
if gorm.IsRecordNotFoundError(err) {
return localLoginSources.GetLoginSourceByID(id)
return db.files.GetByID(id)
}
return nil, err
}
return source, nil
}
type ListLoginSourceOpts struct {
// Whether to only include activated login sources.
OnlyActivated bool
}
func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
var sources []*LoginSource
query := db.Order("id ASC")
if opts.OnlyActivated {
query = query.Where("is_actived = ?", true)
}
err := query.Find(&sources).Error
if err != nil {
return nil, err
}
return append(sources, db.files.List(opts)...), nil
}
func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
if err != nil {
return err
}
for _, source := range db.files.List(ListLoginSourceOpts{}) {
if source.File != nil && source.ID != dflt.ID {
source.File.SetGeneral("is_default", "false")
if err = source.File.Save(); err != nil {
return errors.Wrap(err, "save file")
}
}
}
db.files.Update(dflt)
return nil
}
func (db *loginSources) Save(source *LoginSource) error {
if source.File == nil {
return db.DB.Save(source).Error
}
source.File.SetGeneral("name", source.Name)
source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
if err := source.File.SetConfig(source.Config); err != nil {
return errors.Wrap(err, "set config")
} else if err = source.File.Save(); err != nil {
return errors.Wrap(err, "save file")
}
return nil
}

View File

@ -0,0 +1,389 @@
// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"testing"
"time"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"gogs.io/gogs/internal/errutil"
)
func Test_loginSources(t *testing.T) {
if testing.Short() {
t.Skip()
}
t.Parallel()
tables := []interface{}{new(LoginSource), new(User)}
db := &loginSources{
DB: initTestDB(t, "loginSources", tables...),
}
for _, tc := range []struct {
name string
test func(*testing.T, *loginSources)
}{
{"Create", test_loginSources_Create},
{"Count", test_loginSources_Count},
{"DeleteByID", test_loginSources_DeleteByID},
{"GetByID", test_loginSources_GetByID},
{"List", test_loginSources_List},
{"ResetNonDefault", test_loginSources_ResetNonDefault},
{"Save", test_loginSources_Save},
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := clearTables(db.DB, tables...)
if err != nil {
t.Fatal(err)
}
})
tc.test(t, db)
})
}
}
func test_loginSources_Create(t *testing.T, db *loginSources) {
// Create first login source with name "GitHub"
source, err := db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub",
Activated: true,
Default: false,
Config: &GitHubConfig{
APIEndpoint: "https://api.github.com",
},
})
if err != nil {
t.Fatal(err)
}
// Get it back and check the Created field
source, err = db.GetByID(source.ID)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Created.Format(time.RFC3339))
assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Updated.Format(time.RFC3339))
// Try create second login source with same name should fail
_, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
assert.Equal(t, expErr, err)
}
func test_loginSources_Count(t *testing.T, db *loginSources) {
// Create two login sources, one in database and one as source file.
_, err := db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub",
Activated: true,
Default: false,
Config: &GitHubConfig{
APIEndpoint: "https://api.github.com",
},
})
if err != nil {
t.Fatal(err)
}
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
MockLen: func() int {
return 2
},
})
assert.Equal(t, int64(3), db.Count())
}
func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
t.Run("delete but in used", func(t *testing.T) {
source, err := db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub",
Activated: true,
Default: false,
Config: &GitHubConfig{
APIEndpoint: "https://api.github.com",
},
})
if err != nil {
t.Fatal(err)
}
// Create a user that uses this login source
user := &User{
LoginSource: source.ID,
}
err = db.DB.Create(user).Error
if err != nil {
t.Fatal(err)
}
// Delete the login source will result in error
err = db.DeleteByID(source.ID)
expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
assert.Equal(t, expErr, err)
})
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
MockGetByID: func(id int64) (*LoginSource, error) {
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
},
})
// Create a login source with name "GitHub2"
source, err := db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub2",
Activated: true,
Default: false,
Config: &GitHubConfig{
APIEndpoint: "https://api.github.com",
},
})
if err != nil {
t.Fatal(err)
}
// Delete a non-existent ID is noop
err = db.DeleteByID(9999)
if err != nil {
t.Fatal(err)
}
// We should be able to get it back
_, err = db.GetByID(source.ID)
if err != nil {
t.Fatal(err)
}
// Now delete this login source with ID
err = db.DeleteByID(source.ID)
if err != nil {
t.Fatal(err)
}
// We should get token not found error
_, err = db.GetByID(source.ID)
expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
assert.Equal(t, expErr, err)
}
func test_loginSources_GetByID(t *testing.T, db *loginSources) {
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
MockGetByID: func(id int64) (*LoginSource, error) {
if id != 101 {
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
}
return &LoginSource{ID: id}, nil
},
})
expConfig := &GitHubConfig{
APIEndpoint: "https://api.github.com",
}
// Create a login source with name "GitHub"
source, err := db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub",
Activated: true,
Default: false,
Config: expConfig,
})
if err != nil {
t.Fatal(err)
}
// Get the one in the database and test the read/write hooks
source, err = db.GetByID(source.ID)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, expConfig, source.Config)
// Get the one in source file store
_, err = db.GetByID(101)
if err != nil {
t.Fatal(err)
}
}
func test_loginSources_List(t *testing.T, db *loginSources) {
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
MockList: func(opts ListLoginSourceOpts) []*LoginSource {
if opts.OnlyActivated {
return []*LoginSource{
{ID: 1},
}
}
return []*LoginSource{
{ID: 1},
{ID: 2},
}
},
})
// Create two login sources in database, one activated and the other one not
_, err := db.Create(CreateLoginSourceOpts{
Type: LoginPAM,
Name: "PAM",
Config: &PAMConfig{
ServiceName: "PAM",
},
})
if err != nil {
t.Fatal(err)
}
_, err = db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub",
Activated: true,
Config: &GitHubConfig{
APIEndpoint: "https://api.github.com",
},
})
if err != nil {
t.Fatal(err)
}
// List all login sources
sources, err := db.List(ListLoginSourceOpts{})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 4, len(sources), "number of sources")
// Only list activated login sources
sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 2, len(sources), "number of sources")
}
func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) {
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
MockList: func(opts ListLoginSourceOpts) []*LoginSource {
return []*LoginSource{
{
File: &mockLoginSourceFileStore{
MockSetGeneral: func(name, value string) {
assert.Equal(t, "is_default", name)
assert.Equal(t, "false", value)
},
MockSave: func() error {
return nil
},
},
},
}
},
MockUpdate: func(source *LoginSource) {},
})
// Create two login sources both have default on
source1, err := db.Create(CreateLoginSourceOpts{
Type: LoginPAM,
Name: "PAM",
Default: true,
Config: &PAMConfig{
ServiceName: "PAM",
},
})
if err != nil {
t.Fatal(err)
}
source2, err := db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub",
Activated: true,
Default: true,
Config: &GitHubConfig{
APIEndpoint: "https://api.github.com",
},
})
if err != nil {
t.Fatal(err)
}
// Set source 1 as default
err = db.ResetNonDefault(source1)
if err != nil {
t.Fatal(err)
}
// Verify the default state
source1, err = db.GetByID(source1.ID)
if err != nil {
t.Fatal(err)
}
assert.True(t, source1.IsDefault)
source2, err = db.GetByID(source2.ID)
if err != nil {
t.Fatal(err)
}
assert.False(t, source2.IsDefault)
}
func test_loginSources_Save(t *testing.T, db *loginSources) {
t.Run("save to database", func(t *testing.T) {
// Create a login source with name "GitHub"
source, err := db.Create(CreateLoginSourceOpts{
Type: LoginGitHub,
Name: "GitHub",
Activated: true,
Default: false,
Config: &GitHubConfig{
APIEndpoint: "https://api.github.com",
},
})
if err != nil {
t.Fatal(err)
}
source.IsActived = false
source.Config = &GitHubConfig{
APIEndpoint: "https://api2.github.com",
}
err = db.Save(source)
if err != nil {
t.Fatal(err)
}
source, err = db.GetByID(source.ID)
if err != nil {
t.Fatal(err)
}
assert.False(t, source.IsActived)
assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
})
t.Run("save to file", func(t *testing.T) {
calledSave := false
source := &LoginSource{
File: &mockLoginSourceFileStore{
MockSetGeneral: func(name, value string) {},
MockSetConfig: func(cfg interface{}) error { return nil },
MockSave: func() error {
calledSave = true
return nil
},
},
}
err := db.Save(source)
if err != nil {
t.Fatal(err)
}
assert.True(t, calledSave)
})
}

View File

@ -41,7 +41,8 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func deleteTables(db *gorm.DB, tables ...interface{}) error {
// clearTables removes all rows from given tables.
func clearTables(db *gorm.DB, tables ...interface{}) error {
for _, t := range tables {
err := db.Delete(t).Error
if err != nil {

View File

@ -78,6 +78,59 @@ func SetMockLFSStore(t *testing.T, mock LFSStore) {
})
}
var _ loginSourceFilesStore = (*mockLoginSourceFilesStore)(nil)
type mockLoginSourceFilesStore struct {
MockGetByID func(id int64) (*LoginSource, error)
MockLen func() int
MockList func(opts ListLoginSourceOpts) []*LoginSource
MockUpdate func(source *LoginSource)
}
func (m *mockLoginSourceFilesStore) GetByID(id int64) (*LoginSource, error) {
return m.MockGetByID(id)
}
func (m *mockLoginSourceFilesStore) Len() int {
return m.MockLen()
}
func (m *mockLoginSourceFilesStore) List(opts ListLoginSourceOpts) []*LoginSource {
return m.MockList(opts)
}
func (m *mockLoginSourceFilesStore) Update(source *LoginSource) {
m.MockUpdate(source)
}
func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) {
before := db.files
db.files = mock
t.Cleanup(func() {
db.files = before
})
}
var _ loginSourceFileStore = (*mockLoginSourceFileStore)(nil)
type mockLoginSourceFileStore struct {
MockSetGeneral func(name, value string)
MockSetConfig func(cfg interface{}) error
MockSave func() error
}
func (m *mockLoginSourceFileStore) SetGeneral(name, value string) {
m.MockSetGeneral(name, value)
}
func (m *mockLoginSourceFileStore) SetConfig(cfg interface{}) error {
return m.MockSetConfig(cfg)
}
func (m *mockLoginSourceFileStore) Save() error {
return m.MockSave()
}
var _ PermsStore = (*MockPermsStore)(nil)
type MockPermsStore struct {

View File

@ -53,7 +53,7 @@ func init() {
new(Watch), new(Star), new(Follow), new(Action),
new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser),
new(Label), new(IssueLabel), new(Milestone),
new(Mirror), new(Release), new(LoginSource), new(Webhook), new(HookTask),
new(Mirror), new(Release), new(Webhook), new(HookTask),
new(ProtectBranch), new(ProtectBranchWhitelist),
new(Team), new(OrgUser), new(TeamUser), new(TeamRepo),
new(Notice), new(EmailAddress))
@ -200,7 +200,7 @@ func GetStatistic() (stats Statistic) {
stats.Counter.Follow, _ = x.Count(new(Follow))
stats.Counter.Mirror, _ = x.Count(new(Mirror))
stats.Counter.Release, _ = x.Count(new(Release))
stats.Counter.LoginSource = CountLoginSources()
stats.Counter.LoginSource = LoginSources.Count()
stats.Counter.Webhook, _ = x.Count(new(Webhook))
stats.Counter.Milestone, _ = x.Count(new(Milestone))
stats.Counter.Label, _ = x.Count(new(Label))
@ -295,13 +295,13 @@ func ImportDatabase(dirPath string, verbose bool) (err error) {
tp := LoginType(com.StrTo(com.ToStr(meta["Type"])).MustInt64())
switch tp {
case LoginLDAP, LoginDLDAP:
bean.Cfg = new(LDAPConfig)
bean.Config = new(LDAPConfig)
case LoginSMTP:
bean.Cfg = new(SMTPConfig)
bean.Config = new(SMTPConfig)
case LoginPAM:
bean.Cfg = new(PAMConfig)
bean.Config = new(PAMConfig)
case LoginGitHub:
bean.Cfg = new(GitHubConfig)
bean.Config = new(GitHubConfig)
default:
return fmt.Errorf("unrecognized login source type:: %v", tp)
}

View File

@ -17,8 +17,9 @@ func Test_perms(t *testing.T) {
t.Parallel()
tables := []interface{}{new(Access)}
db := &perms{
DB: initTestDB(t, "perms", new(Access)),
DB: initTestDB(t, "perms", tables...),
}
for _, tc := range []struct {
@ -31,7 +32,7 @@ func Test_perms(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := deleteTables(db.DB, new(Access))
err := clearTables(db.DB, tables...)
if err != nil {
t.Fatal(err)
}

View File

@ -48,33 +48,33 @@ const (
// User represents the object of individual and member of organization.
type User struct {
ID int64
LowerName string `xorm:"UNIQUE NOT NULL"`
Name string `xorm:"UNIQUE NOT NULL"`
LowerName string `xorm:"UNIQUE NOT NULL" gorm:"UNIQUE"`
Name string `xorm:"UNIQUE NOT NULL" gorm:"NOT NULL"`
FullName string
// Email is the primary email address (to be used for communication)
Email string `xorm:"NOT NULL"`
Passwd string `xorm:"NOT NULL"`
Email string `xorm:"NOT NULL" gorm:"NOT NULL"`
Passwd string `xorm:"NOT NULL" gorm:"NOT NULL"`
LoginType LoginType
LoginSource int64 `xorm:"NOT NULL DEFAULT 0"`
LoginSource int64 `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"`
LoginName string
Type UserType
OwnedOrgs []*User `xorm:"-" json:"-"`
Orgs []*User `xorm:"-" json:"-"`
Repos []*Repository `xorm:"-" json:"-"`
OwnedOrgs []*User `xorm:"-" gorm:"-" json:"-"`
Orgs []*User `xorm:"-" gorm:"-" json:"-"`
Repos []*Repository `xorm:"-" gorm:"-" json:"-"`
Location string
Website string
Rands string `xorm:"VARCHAR(10)"`
Salt string `xorm:"VARCHAR(10)"`
Rands string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"`
Salt string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"`
Created time.Time `xorm:"-" json:"-"`
Created time.Time `xorm:"-" gorm:"-" json:"-"`
CreatedUnix int64
Updated time.Time `xorm:"-" json:"-"`
Updated time.Time `xorm:"-" gorm:"-" json:"-"`
UpdatedUnix int64
// Remember visibility choice for convenience, true for private
LastRepoVisibility bool
// Maximum repository creation limit, -1 means use gloabl default
MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1"`
// Maximum repository creation limit, -1 means use global default
MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1" gorm:"NOT NULL;DEFAULT:-1"`
// Permissions
IsActive bool // Activate primary email
@ -84,13 +84,13 @@ type User struct {
ProhibitLogin bool
// Avatar
Avatar string `xorm:"VARCHAR(2048) NOT NULL"`
AvatarEmail string `xorm:"NOT NULL"`
Avatar string `xorm:"VARCHAR(2048) NOT NULL" gorm:"TYPE:VARCHAR(2048);NOT NULL"`
AvatarEmail string `xorm:"NOT NULL" gorm:"NOT NULL"`
UseCustomAvatar bool
// Counters
NumFollowers int
NumFollowing int `xorm:"NOT NULL DEFAULT 0"`
NumFollowing int `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"`
NumStars int
NumRepos int
@ -98,8 +98,8 @@ type User struct {
Description string
NumTeams int
NumMembers int
Teams []*Team `xorm:"-" json:"-"`
Members []*User `xorm:"-" json:"-"`
Teams []*Team `xorm:"-" gorm:"-" json:"-"`
Members []*User `xorm:"-" gorm:"-" json:"-"`
}
func (u *User) BeforeInsert() {

View File

@ -29,6 +29,8 @@ func (w *Writer) Print(v ...interface{}) {
fmt.Fprintf(w.Writer, "[sql] [%s] [%s] %s %v (%d rows affected)", v[1:]...)
case "log":
fmt.Fprintf(w.Writer, "[log] [%s] %s", v[1:]...)
case "error":
fmt.Fprintf(w.Writer, "[err] [%s] %s", v[1:]...)
default:
fmt.Fprint(w.Writer, v...)
}

View File

@ -41,6 +41,11 @@ func TestWriter_Print(t *testing.T) {
vs: []interface{}{"log", "writer.go:65", "something"},
expOutput: "[log] [writer.go:65] something",
},
{
name: "error",
vs: []interface{}{"error", "writer.go:65", "something bad"},
expOutput: "[err] [writer.go:65] something bad",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

View File

@ -17,6 +17,16 @@ func IsFile(path string) bool {
return !f.IsDir()
}
// IsDir returns true if given path is a directory, and returns false when it's
// a file or does not exist.
func IsDir(dir string) bool {
f, e := os.Stat(dir)
if e != nil {
return false
}
return f.IsDir()
}
// IsExist returns true if a file or directory exists.
func IsExist(path string) bool {
_, err := os.Stat(path)

View File

@ -33,6 +33,29 @@ func TestIsFile(t *testing.T) {
}
}
func TestIsDir(t *testing.T) {
tests := []struct {
path string
expVal bool
}{
{
path: "osutil.go",
expVal: false,
}, {
path: "../osutil",
expVal: true,
}, {
path: "not_found",
expVal: false,
},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
assert.Equal(t, test.expVal, IsDir(test.path))
})
}
}
func TestIsExist(t *testing.T) {
tests := []struct {
path string

View File

@ -11,7 +11,6 @@ import (
"github.com/unknwon/com"
log "unknwon.dev/clog/v2"
"xorm.io/core"
"gogs.io/gogs/internal/auth/ldap"
"gogs.io/gogs/internal/conf"
@ -32,13 +31,13 @@ func Authentications(c *context.Context) {
c.PageIs("AdminAuthentications")
var err error
c.Data["Sources"], err = db.ListLoginSources()
c.Data["Sources"], err = db.LoginSources.List(db.ListLoginSourceOpts{})
if err != nil {
c.Error(err, "list login sources")
return
}
c.Data["Total"] = db.CountLoginSources()
c.Data["Total"] = db.LoginSources.Count()
c.Success(AUTHS)
}
@ -56,9 +55,9 @@ var (
{db.LoginNames[db.LoginGitHub], db.LoginGitHub},
}
securityProtocols = []dropdownItem{
{db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_UNENCRYPTED], ldap.SECURITY_PROTOCOL_UNENCRYPTED},
{db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_LDAPS], ldap.SECURITY_PROTOCOL_LDAPS},
{db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_START_TLS], ldap.SECURITY_PROTOCOL_START_TLS},
{db.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted], ldap.SecurityProtocolUnencrypted},
{db.SecurityProtocolNames[ldap.SecurityProtocolLDAPS], ldap.SecurityProtocolLDAPS},
{db.SecurityProtocolNames[ldap.SecurityProtocolStartTLS], ldap.SecurityProtocolStartTLS},
}
)
@ -69,7 +68,7 @@ func NewAuthSource(c *context.Context) {
c.Data["type"] = db.LoginLDAP
c.Data["CurrentTypeName"] = db.LoginNames[db.LoginLDAP]
c.Data["CurrentSecurityProtocol"] = db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_UNENCRYPTED]
c.Data["CurrentSecurityProtocol"] = db.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted]
c.Data["smtp_auth"] = "PLAIN"
c.Data["is_active"] = true
c.Data["is_default"] = true
@ -81,7 +80,7 @@ func NewAuthSource(c *context.Context) {
func parseLDAPConfig(f form.Authentication) *db.LDAPConfig {
return &db.LDAPConfig{
Source: &ldap.Source{
Source: ldap.Source{
Host: f.Host,
Port: f.Port,
SecurityProtocol: ldap.SecurityProtocol(f.SecurityProtocol),
@ -129,11 +128,11 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
c.Data["SMTPAuths"] = db.SMTPAuths
hasTLS := false
var config core.Conversion
var config interface{}
switch db.LoginType(f.Type) {
case db.LoginLDAP, db.LoginDLDAP:
config = parseLDAPConfig(f)
hasTLS = ldap.SecurityProtocol(f.SecurityProtocol) > ldap.SECURITY_PROTOCOL_UNENCRYPTED
hasTLS = ldap.SecurityProtocol(f.SecurityProtocol) > ldap.SecurityProtocolUnencrypted
case db.LoginSMTP:
config = parseSMTPConfig(f)
hasTLS = true
@ -156,22 +155,31 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
return
}
if err := db.CreateLoginSource(&db.LoginSource{
source, err := db.LoginSources.Create(db.CreateLoginSourceOpts{
Type: db.LoginType(f.Type),
Name: f.Name,
IsActived: f.IsActive,
IsDefault: f.IsDefault,
Cfg: config,
}); err != nil {
Activated: f.IsActive,
Default: f.IsDefault,
Config: config,
})
if err != nil {
if db.IsErrLoginSourceAlreadyExist(err) {
c.FormErr("Name")
c.RenderWithErr(c.Tr("admin.auths.login_source_exist", err.(db.ErrLoginSourceAlreadyExist).Name), AUTH_NEW, f)
c.RenderWithErr(c.Tr("admin.auths.login_source_exist", f.Name), AUTH_NEW, f)
} else {
c.Error(err, "create login source")
}
return
}
if source.IsDefault {
err = db.LoginSources.ResetNonDefault(source)
if err != nil {
c.Error(err, "reset non-default login sources")
return
}
}
log.Trace("Authentication created by admin(%s): %s", c.User.Name, f.Name)
c.Flash.Success(c.Tr("admin.auths.new_success", f.Name))
@ -217,7 +225,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
return
}
var config core.Conversion
var config interface{}
switch db.LoginType(f.Type) {
case db.LoginLDAP, db.LoginDLDAP:
config = parseLDAPConfig(f)
@ -239,12 +247,20 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
source.Name = f.Name
source.IsActived = f.IsActive
source.IsDefault = f.IsDefault
source.Cfg = config
if err := db.UpdateLoginSource(source); err != nil {
source.Config = config
if err := db.LoginSources.Save(source); err != nil {
c.Error(err, "update login source")
return
}
if source.IsDefault {
err = db.LoginSources.ResetNonDefault(source)
if err != nil {
c.Error(err, "reset non-default login sources")
return
}
}
log.Trace("Authentication changed by admin '%s': %d", c.User.Name, source.ID)
c.Flash.Success(c.Tr("admin.auths.update_success"))
@ -252,13 +268,8 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
}
func DeleteAuthSource(c *context.Context) {
source, err := db.LoginSources.GetByID(c.ParamsInt64(":authid"))
if err != nil {
c.Error(err, "get login source by ID")
return
}
if err = db.DeleteSource(source); err != nil {
id := c.ParamsInt64(":authid")
if err := db.LoginSources.DeleteByID(id); err != nil {
if db.IsErrLoginSourceInUse(err) {
c.Flash.Error(c.Tr("admin.auths.still_in_used"))
} else {
@ -269,7 +280,7 @@ func DeleteAuthSource(c *context.Context) {
})
return
}
log.Trace("Authentication deleted by admin(%s): %d", c.User.Name, source.ID)
log.Trace("Authentication deleted by admin(%s): %d", c.User.Name, id)
c.Flash.Success(c.Tr("admin.auths.deletion_success"))
c.JSONSuccess(map[string]interface{}{

View File

@ -46,7 +46,7 @@ func NewUser(c *context.Context) {
c.Data["login_type"] = "0-0"
sources, err := db.ListLoginSources()
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
if err != nil {
c.Error(err, "list login sources")
return
@ -62,7 +62,7 @@ func NewUserPost(c *context.Context, f form.AdminCrateUser) {
c.Data["PageIsAdmin"] = true
c.Data["PageIsAdminUsers"] = true
sources, err := db.ListLoginSources()
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
if err != nil {
c.Error(err, "list login sources")
return
@ -141,7 +141,7 @@ func prepareUserInfo(c *context.Context) *db.User {
c.Data["LoginSource"] = &db.LoginSource{}
}
sources, err := db.ListLoginSources()
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
if err != nil {
c.Error(err, "list login sources")
return nil

View File

@ -13,7 +13,6 @@ import (
"gogs.io/gogs/internal/conf"
"gogs.io/gogs/internal/context"
"gogs.io/gogs/internal/db"
"gogs.io/gogs/internal/db/errors"
"gogs.io/gogs/internal/email"
"gogs.io/gogs/internal/route/api/v1/user"
)
@ -25,7 +24,7 @@ func parseLoginSource(c *context.APIContext, u *db.User, sourceID int64, loginNa
source, err := db.LoginSources.GetByID(sourceID)
if err != nil {
if errors.IsLoginSourceNotExist(err) {
if db.IsErrLoginSourceNotExist(err) {
c.ErrorStatus(http.StatusUnprocessableEntity, err)
} else {
c.Error(err, "get login source by ID")

View File

@ -76,7 +76,6 @@ func GlobalInit(customConf string) error {
}
db.HasEngine = true
db.LoadAuthSources()
db.LoadRepoConfig()
db.NewRepoContext()

View File

@ -101,7 +101,7 @@ func Login(c *context.Context) {
}
// Display normal login page
loginSources, err := db.ActivatedLoginSources()
loginSources, err := db.LoginSources.List(db.ListLoginSourceOpts{OnlyActivated: true})
if err != nil {
c.Error(err, "list activated login sources")
return
@ -148,7 +148,7 @@ func afterLogin(c *context.Context, u *db.User, remember bool) {
func LoginPost(c *context.Context, f form.SignIn) {
c.Title("sign_in")
loginSources, err := db.ActivatedLoginSources()
loginSources, err := db.LoginSources.List(db.ListLoginSourceOpts{OnlyActivated: true})
if err != nil {
c.Error(err, "list activated login sources")
return

View File

@ -176,7 +176,7 @@
<input id="github_api_endpoint" name="github_api_endpoint" value="{{$cfg.APIEndpoint}}" placeholder="e.g. https://api.github.com/" required>
</div>
{{end}}
<div class="inline field {{if not .Source.IsSMTP}}hide{{end}}">
<div class="ui checkbox">
<label><strong>{{.i18n.Tr "admin.auths.enable_tls"}}</strong></label>
@ -203,7 +203,7 @@
</div>
<div class="field">
<button class="ui green button">{{.i18n.Tr "admin.auths.update"}}</button>
{{if not .Source.LocalFile}}
{{if not .Source.File}}
<div class="ui red button delete-button" data-url="{{$.Link}}/delete" data-id="{{.Source.ID}}">{{.i18n.Tr "admin.auths.delete"}}</div>
{{end}}
</div>