mirror of https://github.com/gogs/gogs.git
login_source: migrate to GORM and add tests (#6090)
* Use GORM in all write paths * Migrate to GORM * Fix lint errors * Use GORM to init table * dbutil: make writer detect error * Add more tests * Rename to clearTables * db: finish adding tests * osutil: add tests * Fix load source files pathpull/6092/head
parent
76bb647d24
commit
41f56ad05d
File diff suppressed because one or more lines are too long
|
@ -19,9 +19,9 @@ type SecurityProtocol int
|
|||
|
||||
// Note: new type must be added at the end of list to maintain compatibility.
|
||||
const (
|
||||
SECURITY_PROTOCOL_UNENCRYPTED SecurityProtocol = iota
|
||||
SECURITY_PROTOCOL_LDAPS
|
||||
SECURITY_PROTOCOL_START_TLS
|
||||
SecurityProtocolUnencrypted SecurityProtocol = iota
|
||||
SecurityProtocolLDAPS
|
||||
SecurityProtocolStartTLS
|
||||
)
|
||||
|
||||
// Basic LDAP authentication service
|
||||
|
@ -144,7 +144,7 @@ func dial(ls *Source) (*ldap.Conn, error) {
|
|||
ServerName: ls.Host,
|
||||
InsecureSkipVerify: ls.SkipVerify,
|
||||
}
|
||||
if ls.SecurityProtocol == SECURITY_PROTOCOL_LDAPS {
|
||||
if ls.SecurityProtocol == SecurityProtocolLDAPS {
|
||||
return ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ls.Host, ls.Port), tlsCfg)
|
||||
}
|
||||
|
||||
|
@ -153,7 +153,7 @@ func dial(ls *Source) (*ldap.Conn, error) {
|
|||
return nil, fmt.Errorf("Dial: %v", err)
|
||||
}
|
||||
|
||||
if ls.SecurityProtocol == SECURITY_PROTOCOL_START_TLS {
|
||||
if ls.SecurityProtocol == SecurityProtocolStartTLS {
|
||||
if err = conn.StartTLS(tlsCfg); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("StartTLS: %v", err)
|
||||
|
|
|
@ -21,8 +21,9 @@ func Test_accessTokens(t *testing.T) {
|
|||
|
||||
t.Parallel()
|
||||
|
||||
tables := []interface{}{new(AccessToken)}
|
||||
db := &accessTokens{
|
||||
DB: initTestDB(t, "accessTokens", new(AccessToken)),
|
||||
DB: initTestDB(t, "accessTokens", tables...),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
|
@ -37,7 +38,7 @@ func Test_accessTokens(t *testing.T) {
|
|||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
err := deleteTables(db.DB, new(AccessToken))
|
||||
err := clearTables(db.DB, tables...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -78,14 +79,14 @@ func test_accessTokens_DeleteByID(t *testing.T, db *accessTokens) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We should be able to get it back
|
||||
_, err = db.GetBySHA(token.Sha1)
|
||||
// Delete a token with mismatched user ID is noop
|
||||
err = db.DeleteByID(2, token.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete a token with mismatched user ID is noop
|
||||
err = db.DeleteByID(2, token.ID)
|
||||
// We should be able to get it back
|
||||
_, err = db.GetBySHA(token.Sha1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ func getLogWriter() (io.Writer, error) {
|
|||
|
||||
var tables = []interface{}{
|
||||
new(AccessToken),
|
||||
new(LFSObject),
|
||||
new(LFSObject), new(LoginSource),
|
||||
}
|
||||
|
||||
func Init() error {
|
||||
|
@ -167,9 +167,14 @@ func Init() error {
|
|||
return time.Now().UTC().Truncate(time.Microsecond)
|
||||
}
|
||||
|
||||
sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "load login source files")
|
||||
}
|
||||
|
||||
// Initialize stores, sorted in alphabetical order.
|
||||
AccessTokens = &accessTokens{DB: db}
|
||||
LoginSources = &loginSources{DB: db}
|
||||
LoginSources = &loginSources{DB: db, files: sourceFiles}
|
||||
LFS = &lfs{DB: db}
|
||||
Perms = &perms{DB: db}
|
||||
Repos = &repos{DB: db}
|
||||
|
|
|
@ -327,39 +327,6 @@ func (err ErrRepoFileAlreadyExist) Error() string {
|
|||
return fmt.Sprintf("repository file already exists [file_name: %s]", err.FileName)
|
||||
}
|
||||
|
||||
// .____ .__ _________
|
||||
// | | ____ ____ |__| ____ / _____/ ____ __ _________ ____ ____
|
||||
// | | / _ \ / ___\| |/ \ \_____ \ / _ \| | \_ __ \_/ ___\/ __ \
|
||||
// | |__( <_> ) /_/ > | | \ / ( <_> ) | /| | \/\ \__\ ___/
|
||||
// |_______ \____/\___ /|__|___| / /_______ /\____/|____/ |__| \___ >___ >
|
||||
// \/ /_____/ \/ \/ \/ \/
|
||||
|
||||
type ErrLoginSourceAlreadyExist struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func IsErrLoginSourceAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceAlreadyExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists [name: %s]", err.Name)
|
||||
}
|
||||
|
||||
type ErrLoginSourceInUse struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
func IsErrLoginSourceInUse(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceInUse)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
// ___________
|
||||
// \__ ___/___ _____ _____
|
||||
// | |_/ __ \\__ \ / \
|
||||
|
|
|
@ -6,19 +6,6 @@ package errors
|
|||
|
||||
import "fmt"
|
||||
|
||||
type LoginSourceNotExist struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
func IsLoginSourceNotExist(err error) bool {
|
||||
_, ok := err.(LoginSourceNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err LoginSourceNotExist) Error() string {
|
||||
return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
type LoginSourceNotActivated struct {
|
||||
SourceID int64
|
||||
}
|
||||
|
@ -44,4 +31,3 @@ func IsInvalidLoginSourceType(err error) bool {
|
|||
func (err InvalidLoginSourceType) Error() string {
|
||||
return fmt.Sprintf("invalid login source type [type: %v]", err.Type)
|
||||
}
|
||||
|
||||
|
|
|
@ -22,8 +22,9 @@ func Test_lfs(t *testing.T) {
|
|||
|
||||
t.Parallel()
|
||||
|
||||
tables := []interface{}{new(LFSObject)}
|
||||
db := &lfs{
|
||||
DB: initTestDB(t, "lfs", new(LFSObject)),
|
||||
DB: initTestDB(t, "lfs", tables...),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
|
@ -36,7 +37,7 @@ func Test_lfs(t *testing.T) {
|
|||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
err := deleteTables(db.DB, new(LFSObject))
|
||||
err := clearTables(db.DB, tables...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -10,30 +10,21 @@ import (
|
|||
"fmt"
|
||||
"net/smtp"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-macaron/binding"
|
||||
"github.com/json-iterator/go"
|
||||
"github.com/unknwon/com"
|
||||
"gopkg.in/ini.v1"
|
||||
log "unknwon.dev/clog/v2"
|
||||
"xorm.io/core"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"gogs.io/gogs/internal/auth/github"
|
||||
"gogs.io/gogs/internal/auth/ldap"
|
||||
"gogs.io/gogs/internal/auth/pam"
|
||||
"gogs.io/gogs/internal/conf"
|
||||
"gogs.io/gogs/internal/db/errors"
|
||||
)
|
||||
|
||||
type LoginType int
|
||||
|
||||
// Note: new type must append to the end of list to maintain compatibility.
|
||||
// TODO: Move to authutil.
|
||||
const (
|
||||
LoginNotype LoginType = iota
|
||||
LoginPlain // 1
|
||||
|
@ -52,497 +43,24 @@ var LoginNames = map[LoginType]string{
|
|||
LoginGitHub: "GitHub",
|
||||
}
|
||||
|
||||
var SecurityProtocolNames = map[ldap.SecurityProtocol]string{
|
||||
ldap.SECURITY_PROTOCOL_UNENCRYPTED: "Unencrypted",
|
||||
ldap.SECURITY_PROTOCOL_LDAPS: "LDAPS",
|
||||
ldap.SECURITY_PROTOCOL_START_TLS: "StartTLS",
|
||||
}
|
||||
|
||||
// Ensure structs implemented interface.
|
||||
var (
|
||||
_ core.Conversion = &LDAPConfig{}
|
||||
_ core.Conversion = &SMTPConfig{}
|
||||
_ core.Conversion = &PAMConfig{}
|
||||
_ core.Conversion = &GitHubConfig{}
|
||||
)
|
||||
// ***********************
|
||||
// ----- LDAP config -----
|
||||
// ***********************
|
||||
|
||||
type LDAPConfig struct {
|
||||
*ldap.Source `ini:"config"`
|
||||
ldap.Source `ini:"config"`
|
||||
}
|
||||
|
||||
func (cfg *LDAPConfig) FromDB(bs []byte) error {
|
||||
return jsoniter.Unmarshal(bs, &cfg)
|
||||
}
|
||||
|
||||
func (cfg *LDAPConfig) ToDB() ([]byte, error) {
|
||||
return jsoniter.Marshal(cfg)
|
||||
var SecurityProtocolNames = map[ldap.SecurityProtocol]string{
|
||||
ldap.SecurityProtocolUnencrypted: "Unencrypted",
|
||||
ldap.SecurityProtocolLDAPS: "LDAPS",
|
||||
ldap.SecurityProtocolStartTLS: "StartTLS",
|
||||
}
|
||||
|
||||
func (cfg *LDAPConfig) SecurityProtocolName() string {
|
||||
return SecurityProtocolNames[cfg.SecurityProtocol]
|
||||
}
|
||||
|
||||
type SMTPConfig struct {
|
||||
Auth string
|
||||
Host string
|
||||
Port int
|
||||
AllowedDomains string `xorm:"TEXT"`
|
||||
TLS bool `ini:"tls"`
|
||||
SkipVerify bool
|
||||
}
|
||||
|
||||
func (cfg *SMTPConfig) FromDB(bs []byte) error {
|
||||
return jsoniter.Unmarshal(bs, cfg)
|
||||
}
|
||||
|
||||
func (cfg *SMTPConfig) ToDB() ([]byte, error) {
|
||||
return jsoniter.Marshal(cfg)
|
||||
}
|
||||
|
||||
type PAMConfig struct {
|
||||
ServiceName string // PAM service (e.g. system-auth)
|
||||
}
|
||||
|
||||
func (cfg *PAMConfig) FromDB(bs []byte) error {
|
||||
return jsoniter.Unmarshal(bs, &cfg)
|
||||
}
|
||||
|
||||
func (cfg *PAMConfig) ToDB() ([]byte, error) {
|
||||
return jsoniter.Marshal(cfg)
|
||||
}
|
||||
|
||||
type GitHubConfig struct {
|
||||
APIEndpoint string // GitHub service (e.g. https://api.github.com/)
|
||||
}
|
||||
|
||||
func (cfg *GitHubConfig) FromDB(bs []byte) error {
|
||||
return jsoniter.Unmarshal(bs, &cfg)
|
||||
}
|
||||
|
||||
func (cfg *GitHubConfig) ToDB() ([]byte, error) {
|
||||
return jsoniter.Marshal(cfg)
|
||||
}
|
||||
|
||||
// AuthSourceFile contains information of an authentication source file.
|
||||
type AuthSourceFile struct {
|
||||
abspath string
|
||||
file *ini.File
|
||||
}
|
||||
|
||||
// SetGeneral sets new value to the given key in the general (default) section.
|
||||
func (f *AuthSourceFile) SetGeneral(name, value string) {
|
||||
f.file.Section("").Key(name).SetValue(value)
|
||||
}
|
||||
|
||||
// SetConfig sets new values to the "config" section.
|
||||
func (f *AuthSourceFile) SetConfig(cfg core.Conversion) error {
|
||||
return f.file.Section("config").ReflectFrom(cfg)
|
||||
}
|
||||
|
||||
// Save writes updates into file system.
|
||||
func (f *AuthSourceFile) Save() error {
|
||||
return f.file.SaveTo(f.abspath)
|
||||
}
|
||||
|
||||
// LoginSource represents an external way for authorizing users.
|
||||
type LoginSource struct {
|
||||
ID int64
|
||||
Type LoginType
|
||||
Name string `xorm:"UNIQUE"`
|
||||
IsActived bool `xorm:"NOT NULL DEFAULT false"`
|
||||
IsDefault bool `xorm:"DEFAULT false"`
|
||||
Cfg core.Conversion `xorm:"TEXT" gorm:"COLUMN:remove-me-when-migrated-to-gorm"`
|
||||
RawCfg string `xorm:"-" gorm:"COLUMN:cfg"` // TODO: Remove me when migrated to GORM.
|
||||
|
||||
Created time.Time `xorm:"-" json:"-"`
|
||||
CreatedUnix int64
|
||||
Updated time.Time `xorm:"-" json:"-"`
|
||||
UpdatedUnix int64
|
||||
|
||||
LocalFile *AuthSourceFile `xorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
func (s *LoginSource) BeforeInsert() {
|
||||
s.CreatedUnix = time.Now().Unix()
|
||||
s.UpdatedUnix = s.CreatedUnix
|
||||
}
|
||||
|
||||
func (s *LoginSource) BeforeUpdate() {
|
||||
s.UpdatedUnix = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Cell2Int64 converts a xorm.Cell type to int64,
|
||||
// and handles possible irregular cases.
|
||||
func Cell2Int64(val xorm.Cell) int64 {
|
||||
switch (*val).(type) {
|
||||
case []uint8:
|
||||
log.Trace("Cell2Int64 ([]uint8): %v", *val)
|
||||
return com.StrTo(string((*val).([]uint8))).MustInt64()
|
||||
}
|
||||
return (*val).(int64)
|
||||
}
|
||||
|
||||
func (s *LoginSource) BeforeSet(colName string, val xorm.Cell) {
|
||||
switch colName {
|
||||
case "type":
|
||||
switch LoginType(Cell2Int64(val)) {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
s.Cfg = new(LDAPConfig)
|
||||
case LoginSMTP:
|
||||
s.Cfg = new(SMTPConfig)
|
||||
case LoginPAM:
|
||||
s.Cfg = new(PAMConfig)
|
||||
case LoginGitHub:
|
||||
s.Cfg = new(GitHubConfig)
|
||||
default:
|
||||
panic("unrecognized login source type: " + com.ToStr(*val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LoginSource) AfterSet(colName string, _ xorm.Cell) {
|
||||
switch colName {
|
||||
case "created_unix":
|
||||
s.Created = time.Unix(s.CreatedUnix, 0).Local()
|
||||
case "updated_unix":
|
||||
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM query hook.
|
||||
func (s *LoginSource) AfterFind() error {
|
||||
switch s.Type {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
s.Cfg = new(LDAPConfig)
|
||||
case LoginSMTP:
|
||||
s.Cfg = new(SMTPConfig)
|
||||
case LoginPAM:
|
||||
s.Cfg = new(PAMConfig)
|
||||
case LoginGitHub:
|
||||
s.Cfg = new(GitHubConfig)
|
||||
default:
|
||||
return fmt.Errorf("unrecognized login source type: %v", s.Type)
|
||||
}
|
||||
return jsoniter.UnmarshalFromString(s.RawCfg, s.Cfg)
|
||||
}
|
||||
|
||||
func (s *LoginSource) TypeName() string {
|
||||
return LoginNames[s.Type]
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsLDAP() bool {
|
||||
return s.Type == LoginLDAP
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsDLDAP() bool {
|
||||
return s.Type == LoginDLDAP
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsSMTP() bool {
|
||||
return s.Type == LoginSMTP
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsPAM() bool {
|
||||
return s.Type == LoginPAM
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsGitHub() bool {
|
||||
return s.Type == LoginGitHub
|
||||
}
|
||||
|
||||
func (s *LoginSource) HasTLS() bool {
|
||||
return ((s.IsLDAP() || s.IsDLDAP()) &&
|
||||
s.LDAP().SecurityProtocol > ldap.SECURITY_PROTOCOL_UNENCRYPTED) ||
|
||||
s.IsSMTP()
|
||||
}
|
||||
|
||||
func (s *LoginSource) UseTLS() bool {
|
||||
switch s.Type {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
return s.LDAP().SecurityProtocol != ldap.SECURITY_PROTOCOL_UNENCRYPTED
|
||||
case LoginSMTP:
|
||||
return s.SMTP().TLS
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *LoginSource) SkipVerify() bool {
|
||||
switch s.Type {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
return s.LDAP().SkipVerify
|
||||
case LoginSMTP:
|
||||
return s.SMTP().SkipVerify
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *LoginSource) LDAP() *LDAPConfig {
|
||||
return s.Cfg.(*LDAPConfig)
|
||||
}
|
||||
|
||||
func (s *LoginSource) SMTP() *SMTPConfig {
|
||||
return s.Cfg.(*SMTPConfig)
|
||||
}
|
||||
|
||||
func (s *LoginSource) PAM() *PAMConfig {
|
||||
return s.Cfg.(*PAMConfig)
|
||||
}
|
||||
|
||||
func (s *LoginSource) GitHub() *GitHubConfig {
|
||||
return s.Cfg.(*GitHubConfig)
|
||||
}
|
||||
|
||||
func CreateLoginSource(source *LoginSource) error {
|
||||
has, err := x.Get(&LoginSource{Name: source.Name})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrLoginSourceAlreadyExist{source.Name}
|
||||
}
|
||||
|
||||
_, err = x.Insert(source)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if source.IsDefault {
|
||||
return ResetNonDefaultLoginSources(source)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListLoginSources returns all login sources defined.
|
||||
func ListLoginSources() ([]*LoginSource, error) {
|
||||
sources := make([]*LoginSource, 0, 2)
|
||||
if err := x.Find(&sources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(sources, localLoginSources.List()...), nil
|
||||
}
|
||||
|
||||
// ActivatedLoginSources returns login sources that are currently activated.
|
||||
func ActivatedLoginSources() ([]*LoginSource, error) {
|
||||
sources := make([]*LoginSource, 0, 2)
|
||||
if err := x.Where("is_actived = ?", true).Find(&sources); err != nil {
|
||||
return nil, fmt.Errorf("find activated login sources: %v", err)
|
||||
}
|
||||
return append(sources, localLoginSources.ActivatedList()...), nil
|
||||
}
|
||||
|
||||
// ResetNonDefaultLoginSources clean other default source flag
|
||||
func ResetNonDefaultLoginSources(source *LoginSource) error {
|
||||
// update changes to DB
|
||||
if _, err := x.NotIn("id", []int64{source.ID}).Cols("is_default").Update(&LoginSource{IsDefault: false}); err != nil {
|
||||
return err
|
||||
}
|
||||
// write changes to local authentications
|
||||
for i := range localLoginSources.sources {
|
||||
if localLoginSources.sources[i].LocalFile != nil && localLoginSources.sources[i].ID != source.ID {
|
||||
localLoginSources.sources[i].LocalFile.SetGeneral("is_default", "false")
|
||||
if err := localLoginSources.sources[i].LocalFile.SetConfig(source.Cfg); err != nil {
|
||||
return fmt.Errorf("LocalFile.SetConfig: %v", err)
|
||||
} else if err = localLoginSources.sources[i].LocalFile.Save(); err != nil {
|
||||
return fmt.Errorf("LocalFile.Save: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// flush memory so that web page can show the same behaviors
|
||||
localLoginSources.UpdateLoginSource(source)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLoginSource updates information of login source to database or local file.
|
||||
func UpdateLoginSource(source *LoginSource) error {
|
||||
if source.LocalFile == nil {
|
||||
if _, err := x.Id(source.ID).AllCols().Update(source); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return ResetNonDefaultLoginSources(source)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
source.LocalFile.SetGeneral("name", source.Name)
|
||||
source.LocalFile.SetGeneral("is_activated", com.ToStr(source.IsActived))
|
||||
source.LocalFile.SetGeneral("is_default", com.ToStr(source.IsDefault))
|
||||
if err := source.LocalFile.SetConfig(source.Cfg); err != nil {
|
||||
return fmt.Errorf("LocalFile.SetConfig: %v", err)
|
||||
} else if err = source.LocalFile.Save(); err != nil {
|
||||
return fmt.Errorf("LocalFile.Save: %v", err)
|
||||
}
|
||||
return ResetNonDefaultLoginSources(source)
|
||||
}
|
||||
|
||||
func DeleteSource(source *LoginSource) error {
|
||||
count, err := x.Count(&User{LoginSource: source.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count > 0 {
|
||||
return ErrLoginSourceInUse{source.ID}
|
||||
}
|
||||
_, err = x.Id(source.ID).Delete(new(LoginSource))
|
||||
return err
|
||||
}
|
||||
|
||||
// CountLoginSources returns total number of login sources.
|
||||
func CountLoginSources() int64 {
|
||||
count, _ := x.Count(new(LoginSource))
|
||||
return count + int64(localLoginSources.Len())
|
||||
}
|
||||
|
||||
// LocalLoginSources contains authentication sources configured and loaded from local files.
|
||||
// Calling its methods is thread-safe; otherwise, please maintain the mutex accordingly.
|
||||
type LocalLoginSources struct {
|
||||
sync.RWMutex
|
||||
sources []*LoginSource
|
||||
}
|
||||
|
||||
func (s *LocalLoginSources) Len() int {
|
||||
return len(s.sources)
|
||||
}
|
||||
|
||||
// List returns full clone of login sources.
|
||||
func (s *LocalLoginSources) List() []*LoginSource {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
list := make([]*LoginSource, s.Len())
|
||||
for i := range s.sources {
|
||||
list[i] = &LoginSource{}
|
||||
*list[i] = *s.sources[i]
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// ActivatedList returns clone of activated login sources.
|
||||
func (s *LocalLoginSources) ActivatedList() []*LoginSource {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
list := make([]*LoginSource, 0, 2)
|
||||
for i := range s.sources {
|
||||
if !s.sources[i].IsActived {
|
||||
continue
|
||||
}
|
||||
source := &LoginSource{}
|
||||
*source = *s.sources[i]
|
||||
list = append(list, source)
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// GetLoginSourceByID returns a clone of login source by given ID.
|
||||
func (s *LocalLoginSources) GetLoginSourceByID(id int64) (*LoginSource, error) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
for i := range s.sources {
|
||||
if s.sources[i].ID == id {
|
||||
source := &LoginSource{}
|
||||
*source = *s.sources[i]
|
||||
return source, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.LoginSourceNotExist{ID: id}
|
||||
}
|
||||
|
||||
// UpdateLoginSource updates in-memory copy of the authentication source.
|
||||
func (s *LocalLoginSources) UpdateLoginSource(source *LoginSource) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
source.Updated = time.Now()
|
||||
for i := range s.sources {
|
||||
if s.sources[i].ID == source.ID {
|
||||
*s.sources[i] = *source
|
||||
} else if source.IsDefault {
|
||||
s.sources[i].IsDefault = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var localLoginSources = &LocalLoginSources{}
|
||||
|
||||
// LoadAuthSources loads authentication sources from local files
|
||||
// and converts them into login sources.
|
||||
func LoadAuthSources() {
|
||||
authdPath := filepath.Join(conf.CustomDir(), "conf", "auth.d")
|
||||
if !com.IsDir(authdPath) {
|
||||
return
|
||||
}
|
||||
|
||||
paths, err := com.GetFileListBySuffix(authdPath, ".conf")
|
||||
if err != nil {
|
||||
log.Fatal("Failed to list authentication sources: %v", err)
|
||||
}
|
||||
|
||||
localLoginSources.sources = make([]*LoginSource, 0, len(paths))
|
||||
|
||||
for _, fpath := range paths {
|
||||
authSource, err := ini.Load(fpath)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load authentication source: %v", err)
|
||||
}
|
||||
authSource.NameMapper = ini.TitleUnderscore
|
||||
|
||||
// Set general attributes
|
||||
s := authSource.Section("")
|
||||
loginSource := &LoginSource{
|
||||
ID: s.Key("id").MustInt64(),
|
||||
Name: s.Key("name").String(),
|
||||
IsActived: s.Key("is_activated").MustBool(),
|
||||
IsDefault: s.Key("is_default").MustBool(),
|
||||
LocalFile: &AuthSourceFile{
|
||||
abspath: fpath,
|
||||
file: authSource,
|
||||
},
|
||||
}
|
||||
|
||||
fi, err := os.Stat(fpath)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load authentication source: %v", err)
|
||||
}
|
||||
loginSource.Updated = fi.ModTime()
|
||||
|
||||
// Parse authentication source file
|
||||
authType := s.Key("type").String()
|
||||
switch authType {
|
||||
case "ldap_bind_dn":
|
||||
loginSource.Type = LoginLDAP
|
||||
loginSource.Cfg = &LDAPConfig{}
|
||||
case "ldap_simple_auth":
|
||||
loginSource.Type = LoginDLDAP
|
||||
loginSource.Cfg = &LDAPConfig{}
|
||||
case "smtp":
|
||||
loginSource.Type = LoginSMTP
|
||||
loginSource.Cfg = &SMTPConfig{}
|
||||
case "pam":
|
||||
loginSource.Type = LoginPAM
|
||||
loginSource.Cfg = &PAMConfig{}
|
||||
case "github":
|
||||
loginSource.Type = LoginGitHub
|
||||
loginSource.Cfg = &GitHubConfig{}
|
||||
default:
|
||||
log.Fatal("Failed to load authentication source: unknown type '%s'", authType)
|
||||
}
|
||||
|
||||
if err = authSource.Section("config").MapTo(loginSource.Cfg); err != nil {
|
||||
log.Fatal("Failed to parse authentication source 'config': %v", err)
|
||||
}
|
||||
|
||||
localLoginSources.sources = append(localLoginSources.sources, loginSource)
|
||||
}
|
||||
}
|
||||
|
||||
// .____ ________ _____ __________
|
||||
// | | \______ \ / _ \\______ \
|
||||
// | | | | \ / /_\ \| ___/
|
||||
// | |___ | ` \/ | \ |
|
||||
// |_______ \/_______ /\____|__ /____|
|
||||
// \/ \/ \/
|
||||
|
||||
func composeFullName(firstname, surname, username string) string {
|
||||
switch {
|
||||
case len(firstname) == 0 && len(surname) == 0:
|
||||
|
@ -559,7 +77,7 @@ func composeFullName(firstname, surname, username string) string {
|
|||
// LoginViaLDAP queries if login/password is valid against the LDAP directory pool,
|
||||
// and create a local user if success when enabled.
|
||||
func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool) (*User, error) {
|
||||
username, fn, sn, mail, isAdmin, succeed := source.Cfg.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP)
|
||||
username, fn, sn, mail, isAdmin, succeed := source.Config.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP)
|
||||
if !succeed {
|
||||
// User not in LDAP, do nothing
|
||||
return nil, ErrUserNotExist{args: map[string]interface{}{"login": login}}
|
||||
|
@ -606,12 +124,18 @@ func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool
|
|||
return user, CreateUser(user)
|
||||
}
|
||||
|
||||
// _________ __________________________
|
||||
// / _____/ / \__ ___/\______ \
|
||||
// \_____ \ / \ / \| | | ___/
|
||||
// / \/ Y \ | | |
|
||||
// /_______ /\____|__ /____| |____|
|
||||
// \/ \/
|
||||
// ***********************
|
||||
// ----- SMTP config -----
|
||||
// ***********************
|
||||
|
||||
type SMTPConfig struct {
|
||||
Auth string
|
||||
Host string
|
||||
Port int
|
||||
AllowedDomains string
|
||||
TLS bool `ini:"tls"`
|
||||
SkipVerify bool
|
||||
}
|
||||
|
||||
type smtpLoginAuth struct {
|
||||
username, password string
|
||||
|
@ -634,11 +158,11 @@ func (auth *smtpLoginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
|||
}
|
||||
|
||||
const (
|
||||
SMTP_PLAIN = "PLAIN"
|
||||
SMTP_LOGIN = "LOGIN"
|
||||
SMTPPlain = "PLAIN"
|
||||
SMTPLogin = "LOGIN"
|
||||
)
|
||||
|
||||
var SMTPAuths = []string{SMTP_PLAIN, SMTP_LOGIN}
|
||||
var SMTPAuths = []string{SMTPPlain, SMTPLogin}
|
||||
|
||||
func SMTPAuth(a smtp.Auth, cfg *SMTPConfig) error {
|
||||
c, err := smtp.Dial(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port))
|
||||
|
@ -687,9 +211,9 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR
|
|||
}
|
||||
|
||||
var auth smtp.Auth
|
||||
if cfg.Auth == SMTP_PLAIN {
|
||||
if cfg.Auth == SMTPPlain {
|
||||
auth = smtp.PlainAuth("", login, password, cfg.Host)
|
||||
} else if cfg.Auth == SMTP_LOGIN {
|
||||
} else if cfg.Auth == SMTPLogin {
|
||||
auth = &smtpLoginAuth{login, password}
|
||||
} else {
|
||||
return nil, errors.New("Unsupported SMTP authentication type")
|
||||
|
@ -729,12 +253,14 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR
|
|||
return user, CreateUser(user)
|
||||
}
|
||||
|
||||
// __________ _____ _____
|
||||
// \______ \/ _ \ / \
|
||||
// | ___/ /_\ \ / \ / \
|
||||
// | | / | \/ Y \
|
||||
// |____| \____|__ /\____|__ /
|
||||
// \/ \/
|
||||
// **********************
|
||||
// ----- PAM config -----
|
||||
// **********************
|
||||
|
||||
type PAMConfig struct {
|
||||
// The name of the PAM service, e.g. system-auth.
|
||||
ServiceName string
|
||||
}
|
||||
|
||||
// LoginViaPAM queries if login/password is valid against the PAM,
|
||||
// and create a local user if success when enabled.
|
||||
|
@ -763,12 +289,14 @@ func LoginViaPAM(login, password string, sourceID int64, cfg *PAMConfig, autoReg
|
|||
return user, CreateUser(user)
|
||||
}
|
||||
|
||||
// ________.__ __ ___ ___ ___.
|
||||
// / _____/|__|/ |_ / | \ __ _\_ |__
|
||||
// / \ ___| \ __\/ ~ \ | \ __ \
|
||||
// \ \_\ \ || | \ Y / | / \_\ \
|
||||
// \______ /__||__| \___|_ /|____/|___ /
|
||||
// \/ \/ \/
|
||||
// *************************
|
||||
// ----- GitHub config -----
|
||||
// *************************
|
||||
|
||||
type GitHubConfig struct {
|
||||
// the GitHub service endpoint, e.g. https://api.github.com/.
|
||||
APIEndpoint string
|
||||
}
|
||||
|
||||
func LoginViaGitHub(login, password string, sourceID int64, cfg *GitHubConfig, autoRegister bool) (*User, error) {
|
||||
fullname, email, url, location, err := github.Authenticate(cfg.APIEndpoint, login, password)
|
||||
|
@ -807,11 +335,11 @@ func authenticateViaLoginSource(source *LoginSource, login, password string, aut
|
|||
case LoginLDAP, LoginDLDAP:
|
||||
return LoginViaLDAP(login, password, source, autoRegister)
|
||||
case LoginSMTP:
|
||||
return LoginViaSMTP(login, password, source.ID, source.Cfg.(*SMTPConfig), autoRegister)
|
||||
return LoginViaSMTP(login, password, source.ID, source.Config.(*SMTPConfig), autoRegister)
|
||||
case LoginPAM:
|
||||
return LoginViaPAM(login, password, source.ID, source.Cfg.(*PAMConfig), autoRegister)
|
||||
return LoginViaPAM(login, password, source.ID, source.Config.(*PAMConfig), autoRegister)
|
||||
case LoginGitHub:
|
||||
return LoginViaGitHub(login, password, source.ID, source.Cfg.(*GitHubConfig), autoRegister)
|
||||
return LoginViaGitHub(login, password, source.ID, source.Config.(*GitHubConfig), autoRegister)
|
||||
}
|
||||
|
||||
return nil, errors.InvalidLoginSourceType{Type: source.Type}
|
||||
|
|
|
@ -0,0 +1,212 @@
|
|||
// Copyright 2020 The Gogs Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/ini.v1"
|
||||
|
||||
"gogs.io/gogs/internal/errutil"
|
||||
"gogs.io/gogs/internal/osutil"
|
||||
)
|
||||
|
||||
// loginSourceFilesStore is the in-memory interface for login source files stored on file system.
|
||||
//
|
||||
// NOTE: All methods are sorted in alphabetical order.
|
||||
type loginSourceFilesStore interface {
|
||||
// GetByID returns a clone of login source by given ID.
|
||||
GetByID(id int64) (*LoginSource, error)
|
||||
// Len returns number of login sources.
|
||||
Len() int
|
||||
// List returns a list of login sources filtered by options.
|
||||
List(opts ListLoginSourceOpts) []*LoginSource
|
||||
// Update updates in-memory copy of the authentication source.
|
||||
Update(source *LoginSource)
|
||||
}
|
||||
|
||||
var _ loginSourceFilesStore = (*loginSourceFiles)(nil)
|
||||
|
||||
// loginSourceFiles contains authentication sources configured and loaded from local files.
|
||||
type loginSourceFiles struct {
|
||||
sync.RWMutex
|
||||
sources []*LoginSource
|
||||
}
|
||||
|
||||
var _ errutil.NotFound = (*ErrLoginSourceNotExist)(nil)
|
||||
|
||||
type ErrLoginSourceNotExist struct {
|
||||
args errutil.Args
|
||||
}
|
||||
|
||||
func IsErrLoginSourceNotExist(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceNotExist) Error() string {
|
||||
return fmt.Sprintf("login source does not exist: %v", err.args)
|
||||
}
|
||||
|
||||
func (ErrLoginSourceNotExist) NotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *loginSourceFiles) GetByID(id int64) (*LoginSource, error) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
for _, source := range s.sources {
|
||||
if source.ID == id {
|
||||
return source, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
}
|
||||
|
||||
func (s *loginSourceFiles) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return len(s.sources)
|
||||
}
|
||||
|
||||
func (s *loginSourceFiles) List(opts ListLoginSourceOpts) []*LoginSource {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
list := make([]*LoginSource, 0, s.Len())
|
||||
for _, source := range s.sources {
|
||||
if opts.OnlyActivated && !source.IsActived {
|
||||
continue
|
||||
}
|
||||
|
||||
list = append(list, source)
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func (s *loginSourceFiles) Update(source *LoginSource) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
source.Updated = gorm.NowFunc()
|
||||
for _, old := range s.sources {
|
||||
if old.ID == source.ID {
|
||||
*old = *source
|
||||
} else if source.IsDefault {
|
||||
old.IsDefault = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadLoginSourceFiles loads login sources from file system.
|
||||
func loadLoginSourceFiles(authdPath string) (loginSourceFilesStore, error) {
|
||||
if !osutil.IsDir(authdPath) {
|
||||
return &loginSourceFiles{}, nil
|
||||
}
|
||||
|
||||
store := &loginSourceFiles{}
|
||||
return store, filepath.Walk(authdPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if path == authdPath || !strings.HasSuffix(path, ".conf") {
|
||||
return nil
|
||||
} else if info.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
authSource, err := ini.Load(path)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "load file")
|
||||
}
|
||||
authSource.NameMapper = ini.TitleUnderscore
|
||||
|
||||
// Set general attributes
|
||||
s := authSource.Section("")
|
||||
loginSource := &LoginSource{
|
||||
ID: s.Key("id").MustInt64(),
|
||||
Name: s.Key("name").String(),
|
||||
IsActived: s.Key("is_activated").MustBool(),
|
||||
IsDefault: s.Key("is_default").MustBool(),
|
||||
File: &loginSourceFile{
|
||||
path: path,
|
||||
file: authSource,
|
||||
},
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "stat file")
|
||||
}
|
||||
loginSource.Updated = fi.ModTime()
|
||||
|
||||
// Parse authentication source file
|
||||
authType := s.Key("type").String()
|
||||
switch authType {
|
||||
case "ldap_bind_dn":
|
||||
loginSource.Type = LoginLDAP
|
||||
loginSource.Config = &LDAPConfig{}
|
||||
case "ldap_simple_auth":
|
||||
loginSource.Type = LoginDLDAP
|
||||
loginSource.Config = &LDAPConfig{}
|
||||
case "smtp":
|
||||
loginSource.Type = LoginSMTP
|
||||
loginSource.Config = &SMTPConfig{}
|
||||
case "pam":
|
||||
loginSource.Type = LoginPAM
|
||||
loginSource.Config = &PAMConfig{}
|
||||
case "github":
|
||||
loginSource.Type = LoginGitHub
|
||||
loginSource.Config = &GitHubConfig{}
|
||||
default:
|
||||
return fmt.Errorf("unknown type %q", authType)
|
||||
}
|
||||
|
||||
if err = authSource.Section("config").MapTo(loginSource.Config); err != nil {
|
||||
return errors.Wrap(err, `map "config" section`)
|
||||
}
|
||||
|
||||
store.sources = append(store.sources, loginSource)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// loginSourceFileStore is the persistent interface for a login source file.
|
||||
type loginSourceFileStore interface {
|
||||
// SetGeneral sets new value to the given key in the general (default) section.
|
||||
SetGeneral(name, value string)
|
||||
// SetConfig sets new values to the "config" section.
|
||||
SetConfig(cfg interface{}) error
|
||||
// Save persists values to file system.
|
||||
Save() error
|
||||
}
|
||||
|
||||
var _ loginSourceFileStore = (*loginSourceFile)(nil)
|
||||
|
||||
type loginSourceFile struct {
|
||||
path string
|
||||
file *ini.File
|
||||
}
|
||||
|
||||
func (f *loginSourceFile) SetGeneral(name, value string) {
|
||||
f.file.Section("").Key(name).SetValue(value)
|
||||
}
|
||||
|
||||
func (f *loginSourceFile) SetConfig(cfg interface{}) error {
|
||||
return f.file.Section("config").ReflectFrom(cfg)
|
||||
}
|
||||
|
||||
func (f *loginSourceFile) Save() error {
|
||||
return f.file.SaveTo(f.path)
|
||||
}
|
|
@ -5,22 +5,242 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"gogs.io/gogs/internal/auth/ldap"
|
||||
"gogs.io/gogs/internal/errutil"
|
||||
)
|
||||
|
||||
// LoginSourcesStore is the persistent interface for login sources.
|
||||
//
|
||||
// NOTE: All methods are sorted in alphabetical order.
|
||||
type LoginSourcesStore interface {
|
||||
// Create creates a new login source and persist to database.
|
||||
// It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
|
||||
Create(opts CreateLoginSourceOpts) (*LoginSource, error)
|
||||
// Count returns the total number of login sources.
|
||||
Count() int64
|
||||
// DeleteByID deletes a login source by given ID.
|
||||
// It returns ErrLoginSourceInUse if at least one user is associated with the login source.
|
||||
DeleteByID(id int64) error
|
||||
// GetByID returns the login source with given ID.
|
||||
// It returns ErrLoginSourceNotExist when not found.
|
||||
GetByID(id int64) (*LoginSource, error)
|
||||
// List returns a list of login sources filtered by options.
|
||||
List(opts ListLoginSourceOpts) ([]*LoginSource, error)
|
||||
// ResetNonDefault clears default flag for all the other login sources.
|
||||
ResetNonDefault(source *LoginSource) error
|
||||
// Save persists all values of given login source to database or local file.
|
||||
// The Updated field is set to current time automatically.
|
||||
Save(t *LoginSource) error
|
||||
}
|
||||
|
||||
var LoginSources LoginSourcesStore
|
||||
|
||||
// LoginSource represents an external way for authorizing users.
|
||||
type LoginSource struct {
|
||||
ID int64
|
||||
Type LoginType
|
||||
Name string `xorm:"UNIQUE"`
|
||||
IsActived bool `xorm:"NOT NULL DEFAULT false"`
|
||||
IsDefault bool `xorm:"DEFAULT false"`
|
||||
Config interface{} `xorm:"-" gorm:"-"`
|
||||
RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg"`
|
||||
|
||||
Created time.Time `xorm:"-" gorm:"-" json:"-"`
|
||||
CreatedUnix int64
|
||||
Updated time.Time `xorm:"-" gorm:"-" json:"-"`
|
||||
UpdatedUnix int64
|
||||
|
||||
File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM save hook.
|
||||
func (s *LoginSource) BeforeSave() (err error) {
|
||||
s.RawConfig, err = jsoniter.MarshalToString(s.Config)
|
||||
return err
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM create hook.
|
||||
func (s *LoginSource) BeforeCreate() {
|
||||
s.CreatedUnix = gorm.NowFunc().Unix()
|
||||
s.UpdatedUnix = s.CreatedUnix
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM update hook.
|
||||
func (s *LoginSource) BeforeUpdate() {
|
||||
s.UpdatedUnix = gorm.NowFunc().Unix()
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM query hook.
|
||||
func (s *LoginSource) AfterFind() error {
|
||||
s.Created = time.Unix(s.CreatedUnix, 0).Local()
|
||||
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
|
||||
|
||||
switch s.Type {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
s.Config = new(LDAPConfig)
|
||||
case LoginSMTP:
|
||||
s.Config = new(SMTPConfig)
|
||||
case LoginPAM:
|
||||
s.Config = new(PAMConfig)
|
||||
case LoginGitHub:
|
||||
s.Config = new(GitHubConfig)
|
||||
default:
|
||||
return fmt.Errorf("unrecognized login source type: %v", s.Type)
|
||||
}
|
||||
return jsoniter.UnmarshalFromString(s.RawConfig, s.Config)
|
||||
}
|
||||
|
||||
func (s *LoginSource) TypeName() string {
|
||||
return LoginNames[s.Type]
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsLDAP() bool {
|
||||
return s.Type == LoginLDAP
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsDLDAP() bool {
|
||||
return s.Type == LoginDLDAP
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsSMTP() bool {
|
||||
return s.Type == LoginSMTP
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsPAM() bool {
|
||||
return s.Type == LoginPAM
|
||||
}
|
||||
|
||||
func (s *LoginSource) IsGitHub() bool {
|
||||
return s.Type == LoginGitHub
|
||||
}
|
||||
|
||||
func (s *LoginSource) HasTLS() bool {
|
||||
return ((s.IsLDAP() || s.IsDLDAP()) &&
|
||||
s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) ||
|
||||
s.IsSMTP()
|
||||
}
|
||||
|
||||
func (s *LoginSource) UseTLS() bool {
|
||||
switch s.Type {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted
|
||||
case LoginSMTP:
|
||||
return s.SMTP().TLS
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *LoginSource) SkipVerify() bool {
|
||||
switch s.Type {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
return s.LDAP().SkipVerify
|
||||
case LoginSMTP:
|
||||
return s.SMTP().SkipVerify
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *LoginSource) LDAP() *LDAPConfig {
|
||||
return s.Config.(*LDAPConfig)
|
||||
}
|
||||
|
||||
func (s *LoginSource) SMTP() *SMTPConfig {
|
||||
return s.Config.(*SMTPConfig)
|
||||
}
|
||||
|
||||
func (s *LoginSource) PAM() *PAMConfig {
|
||||
return s.Config.(*PAMConfig)
|
||||
}
|
||||
|
||||
func (s *LoginSource) GitHub() *GitHubConfig {
|
||||
return s.Config.(*GitHubConfig)
|
||||
}
|
||||
|
||||
var _ LoginSourcesStore = (*loginSources)(nil)
|
||||
|
||||
type loginSources struct {
|
||||
*gorm.DB
|
||||
files loginSourceFilesStore
|
||||
}
|
||||
|
||||
type CreateLoginSourceOpts struct {
|
||||
Type LoginType
|
||||
Name string
|
||||
Activated bool
|
||||
Default bool
|
||||
Config interface{}
|
||||
}
|
||||
|
||||
type ErrLoginSourceAlreadyExist struct {
|
||||
args errutil.Args
|
||||
}
|
||||
|
||||
func IsErrLoginSourceAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceAlreadyExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
|
||||
err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
|
||||
if err == nil {
|
||||
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
|
||||
} else if !gorm.IsRecordNotFoundError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
source := &LoginSource{
|
||||
Type: opts.Type,
|
||||
Name: opts.Name,
|
||||
IsActived: opts.Activated,
|
||||
IsDefault: opts.Default,
|
||||
Config: opts.Config,
|
||||
}
|
||||
return source, db.DB.Create(source).Error
|
||||
}
|
||||
|
||||
func (db *loginSources) Count() int64 {
|
||||
var count int64
|
||||
db.Model(new(LoginSource)).Count(&count)
|
||||
return count + int64(db.files.Len())
|
||||
}
|
||||
|
||||
type ErrLoginSourceInUse struct {
|
||||
args errutil.Args
|
||||
}
|
||||
|
||||
func IsErrLoginSourceInUse(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceInUse)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *loginSources) DeleteByID(id int64) error {
|
||||
var count int64
|
||||
err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count > 0 {
|
||||
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
|
||||
}
|
||||
|
||||
return db.Where("id = ?", id).Delete(new(LoginSource)).Error
|
||||
}
|
||||
|
||||
func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
|
||||
|
@ -28,9 +248,63 @@ func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
|
|||
err := db.Where("id = ?", id).First(source).Error
|
||||
if err != nil {
|
||||
if gorm.IsRecordNotFoundError(err) {
|
||||
return localLoginSources.GetLoginSourceByID(id)
|
||||
return db.files.GetByID(id)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return source, nil
|
||||
}
|
||||
|
||||
type ListLoginSourceOpts struct {
|
||||
// Whether to only include activated login sources.
|
||||
OnlyActivated bool
|
||||
}
|
||||
|
||||
func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
|
||||
var sources []*LoginSource
|
||||
query := db.Order("id ASC")
|
||||
if opts.OnlyActivated {
|
||||
query = query.Where("is_actived = ?", true)
|
||||
}
|
||||
err := query.Find(&sources).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(sources, db.files.List(opts)...), nil
|
||||
}
|
||||
|
||||
func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
|
||||
err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, source := range db.files.List(ListLoginSourceOpts{}) {
|
||||
if source.File != nil && source.ID != dflt.ID {
|
||||
source.File.SetGeneral("is_default", "false")
|
||||
if err = source.File.Save(); err != nil {
|
||||
return errors.Wrap(err, "save file")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.files.Update(dflt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *loginSources) Save(source *LoginSource) error {
|
||||
if source.File == nil {
|
||||
return db.DB.Save(source).Error
|
||||
}
|
||||
|
||||
source.File.SetGeneral("name", source.Name)
|
||||
source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
|
||||
source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
|
||||
if err := source.File.SetConfig(source.Config); err != nil {
|
||||
return errors.Wrap(err, "set config")
|
||||
} else if err = source.File.Save(); err != nil {
|
||||
return errors.Wrap(err, "save file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,389 @@
|
|||
// Copyright 2020 The Gogs Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"gogs.io/gogs/internal/errutil"
|
||||
)
|
||||
|
||||
func Test_loginSources(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
tables := []interface{}{new(LoginSource), new(User)}
|
||||
db := &loginSources{
|
||||
DB: initTestDB(t, "loginSources", tables...),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(*testing.T, *loginSources)
|
||||
}{
|
||||
{"Create", test_loginSources_Create},
|
||||
{"Count", test_loginSources_Count},
|
||||
{"DeleteByID", test_loginSources_DeleteByID},
|
||||
{"GetByID", test_loginSources_GetByID},
|
||||
{"List", test_loginSources_List},
|
||||
{"ResetNonDefault", test_loginSources_ResetNonDefault},
|
||||
{"Save", test_loginSources_Save},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
err := clearTables(db.DB, tables...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
tc.test(t, db)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func test_loginSources_Create(t *testing.T, db *loginSources) {
|
||||
// Create first login source with name "GitHub"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get it back and check the Created field
|
||||
source, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Created.Format(time.RFC3339))
|
||||
assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Updated.Format(time.RFC3339))
|
||||
|
||||
// Try create second login source with same name should fail
|
||||
_, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
|
||||
expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
|
||||
assert.Equal(t, expErr, err)
|
||||
}
|
||||
|
||||
func test_loginSources_Count(t *testing.T, db *loginSources) {
|
||||
// Create two login sources, one in database and one as source file.
|
||||
_, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockLen: func() int {
|
||||
return 2
|
||||
},
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(3), db.Count())
|
||||
}
|
||||
|
||||
func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
|
||||
t.Run("delete but in used", func(t *testing.T) {
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a user that uses this login source
|
||||
user := &User{
|
||||
LoginSource: source.ID,
|
||||
}
|
||||
err = db.DB.Create(user).Error
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete the login source will result in error
|
||||
err = db.DeleteByID(source.ID)
|
||||
expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
|
||||
assert.Equal(t, expErr, err)
|
||||
})
|
||||
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockGetByID: func(id int64) (*LoginSource, error) {
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
},
|
||||
})
|
||||
|
||||
// Create a login source with name "GitHub2"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub2",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete a non-existent ID is noop
|
||||
err = db.DeleteByID(9999)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We should be able to get it back
|
||||
_, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Now delete this login source with ID
|
||||
err = db.DeleteByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We should get token not found error
|
||||
_, err = db.GetByID(source.ID)
|
||||
expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
|
||||
assert.Equal(t, expErr, err)
|
||||
}
|
||||
|
||||
func test_loginSources_GetByID(t *testing.T, db *loginSources) {
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockGetByID: func(id int64) (*LoginSource, error) {
|
||||
if id != 101 {
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
}
|
||||
return &LoginSource{ID: id}, nil
|
||||
},
|
||||
})
|
||||
|
||||
expConfig := &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
}
|
||||
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: expConfig,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get the one in the database and test the read/write hooks
|
||||
source, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, expConfig, source.Config)
|
||||
|
||||
// Get the one in source file store
|
||||
_, err = db.GetByID(101)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func test_loginSources_List(t *testing.T, db *loginSources) {
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockList: func(opts ListLoginSourceOpts) []*LoginSource {
|
||||
if opts.OnlyActivated {
|
||||
return []*LoginSource{
|
||||
{ID: 1},
|
||||
}
|
||||
}
|
||||
return []*LoginSource{
|
||||
{ID: 1},
|
||||
{ID: 2},
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
// Create two login sources in database, one activated and the other one not
|
||||
_, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginPAM,
|
||||
Name: "PAM",
|
||||
Config: &PAMConfig{
|
||||
ServiceName: "PAM",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Config: &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// List all login sources
|
||||
sources, err := db.List(ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, 4, len(sources), "number of sources")
|
||||
|
||||
// Only list activated login sources
|
||||
sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, 2, len(sources), "number of sources")
|
||||
}
|
||||
|
||||
func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) {
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockList: func(opts ListLoginSourceOpts) []*LoginSource {
|
||||
return []*LoginSource{
|
||||
{
|
||||
File: &mockLoginSourceFileStore{
|
||||
MockSetGeneral: func(name, value string) {
|
||||
assert.Equal(t, "is_default", name)
|
||||
assert.Equal(t, "false", value)
|
||||
},
|
||||
MockSave: func() error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
MockUpdate: func(source *LoginSource) {},
|
||||
})
|
||||
|
||||
// Create two login sources both have default on
|
||||
source1, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginPAM,
|
||||
Name: "PAM",
|
||||
Default: true,
|
||||
Config: &PAMConfig{
|
||||
ServiceName: "PAM",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
source2, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: true,
|
||||
Config: &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set source 1 as default
|
||||
err = db.ResetNonDefault(source1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify the default state
|
||||
source1, err = db.GetByID(source1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, source1.IsDefault)
|
||||
|
||||
source2, err = db.GetByID(source2.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.False(t, source2.IsDefault)
|
||||
}
|
||||
|
||||
func test_loginSources_Save(t *testing.T, db *loginSources) {
|
||||
t.Run("save to database", func(t *testing.T) {
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: LoginGitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &GitHubConfig{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
source.IsActived = false
|
||||
source.Config = &GitHubConfig{
|
||||
APIEndpoint: "https://api2.github.com",
|
||||
}
|
||||
err = db.Save(source)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
source, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.False(t, source.IsActived)
|
||||
assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
|
||||
})
|
||||
|
||||
t.Run("save to file", func(t *testing.T) {
|
||||
calledSave := false
|
||||
source := &LoginSource{
|
||||
File: &mockLoginSourceFileStore{
|
||||
MockSetGeneral: func(name, value string) {},
|
||||
MockSetConfig: func(cfg interface{}) error { return nil },
|
||||
MockSave: func() error {
|
||||
calledSave = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
err := db.Save(source)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, calledSave)
|
||||
})
|
||||
}
|
|
@ -41,7 +41,8 @@ func TestMain(m *testing.M) {
|
|||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func deleteTables(db *gorm.DB, tables ...interface{}) error {
|
||||
// clearTables removes all rows from given tables.
|
||||
func clearTables(db *gorm.DB, tables ...interface{}) error {
|
||||
for _, t := range tables {
|
||||
err := db.Delete(t).Error
|
||||
if err != nil {
|
||||
|
|
|
@ -78,6 +78,59 @@ func SetMockLFSStore(t *testing.T, mock LFSStore) {
|
|||
})
|
||||
}
|
||||
|
||||
var _ loginSourceFilesStore = (*mockLoginSourceFilesStore)(nil)
|
||||
|
||||
type mockLoginSourceFilesStore struct {
|
||||
MockGetByID func(id int64) (*LoginSource, error)
|
||||
MockLen func() int
|
||||
MockList func(opts ListLoginSourceOpts) []*LoginSource
|
||||
MockUpdate func(source *LoginSource)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) GetByID(id int64) (*LoginSource, error) {
|
||||
return m.MockGetByID(id)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) Len() int {
|
||||
return m.MockLen()
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) List(opts ListLoginSourceOpts) []*LoginSource {
|
||||
return m.MockList(opts)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) Update(source *LoginSource) {
|
||||
m.MockUpdate(source)
|
||||
}
|
||||
|
||||
func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) {
|
||||
before := db.files
|
||||
db.files = mock
|
||||
t.Cleanup(func() {
|
||||
db.files = before
|
||||
})
|
||||
}
|
||||
|
||||
var _ loginSourceFileStore = (*mockLoginSourceFileStore)(nil)
|
||||
|
||||
type mockLoginSourceFileStore struct {
|
||||
MockSetGeneral func(name, value string)
|
||||
MockSetConfig func(cfg interface{}) error
|
||||
MockSave func() error
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFileStore) SetGeneral(name, value string) {
|
||||
m.MockSetGeneral(name, value)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFileStore) SetConfig(cfg interface{}) error {
|
||||
return m.MockSetConfig(cfg)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFileStore) Save() error {
|
||||
return m.MockSave()
|
||||
}
|
||||
|
||||
var _ PermsStore = (*MockPermsStore)(nil)
|
||||
|
||||
type MockPermsStore struct {
|
||||
|
|
|
@ -53,7 +53,7 @@ func init() {
|
|||
new(Watch), new(Star), new(Follow), new(Action),
|
||||
new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser),
|
||||
new(Label), new(IssueLabel), new(Milestone),
|
||||
new(Mirror), new(Release), new(LoginSource), new(Webhook), new(HookTask),
|
||||
new(Mirror), new(Release), new(Webhook), new(HookTask),
|
||||
new(ProtectBranch), new(ProtectBranchWhitelist),
|
||||
new(Team), new(OrgUser), new(TeamUser), new(TeamRepo),
|
||||
new(Notice), new(EmailAddress))
|
||||
|
@ -200,7 +200,7 @@ func GetStatistic() (stats Statistic) {
|
|||
stats.Counter.Follow, _ = x.Count(new(Follow))
|
||||
stats.Counter.Mirror, _ = x.Count(new(Mirror))
|
||||
stats.Counter.Release, _ = x.Count(new(Release))
|
||||
stats.Counter.LoginSource = CountLoginSources()
|
||||
stats.Counter.LoginSource = LoginSources.Count()
|
||||
stats.Counter.Webhook, _ = x.Count(new(Webhook))
|
||||
stats.Counter.Milestone, _ = x.Count(new(Milestone))
|
||||
stats.Counter.Label, _ = x.Count(new(Label))
|
||||
|
@ -295,13 +295,13 @@ func ImportDatabase(dirPath string, verbose bool) (err error) {
|
|||
tp := LoginType(com.StrTo(com.ToStr(meta["Type"])).MustInt64())
|
||||
switch tp {
|
||||
case LoginLDAP, LoginDLDAP:
|
||||
bean.Cfg = new(LDAPConfig)
|
||||
bean.Config = new(LDAPConfig)
|
||||
case LoginSMTP:
|
||||
bean.Cfg = new(SMTPConfig)
|
||||
bean.Config = new(SMTPConfig)
|
||||
case LoginPAM:
|
||||
bean.Cfg = new(PAMConfig)
|
||||
bean.Config = new(PAMConfig)
|
||||
case LoginGitHub:
|
||||
bean.Cfg = new(GitHubConfig)
|
||||
bean.Config = new(GitHubConfig)
|
||||
default:
|
||||
return fmt.Errorf("unrecognized login source type:: %v", tp)
|
||||
}
|
||||
|
|
|
@ -17,8 +17,9 @@ func Test_perms(t *testing.T) {
|
|||
|
||||
t.Parallel()
|
||||
|
||||
tables := []interface{}{new(Access)}
|
||||
db := &perms{
|
||||
DB: initTestDB(t, "perms", new(Access)),
|
||||
DB: initTestDB(t, "perms", tables...),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
|
@ -31,7 +32,7 @@ func Test_perms(t *testing.T) {
|
|||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
err := deleteTables(db.DB, new(Access))
|
||||
err := clearTables(db.DB, tables...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -48,33 +48,33 @@ const (
|
|||
// User represents the object of individual and member of organization.
|
||||
type User struct {
|
||||
ID int64
|
||||
LowerName string `xorm:"UNIQUE NOT NULL"`
|
||||
Name string `xorm:"UNIQUE NOT NULL"`
|
||||
LowerName string `xorm:"UNIQUE NOT NULL" gorm:"UNIQUE"`
|
||||
Name string `xorm:"UNIQUE NOT NULL" gorm:"NOT NULL"`
|
||||
FullName string
|
||||
// Email is the primary email address (to be used for communication)
|
||||
Email string `xorm:"NOT NULL"`
|
||||
Passwd string `xorm:"NOT NULL"`
|
||||
Email string `xorm:"NOT NULL" gorm:"NOT NULL"`
|
||||
Passwd string `xorm:"NOT NULL" gorm:"NOT NULL"`
|
||||
LoginType LoginType
|
||||
LoginSource int64 `xorm:"NOT NULL DEFAULT 0"`
|
||||
LoginSource int64 `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"`
|
||||
LoginName string
|
||||
Type UserType
|
||||
OwnedOrgs []*User `xorm:"-" json:"-"`
|
||||
Orgs []*User `xorm:"-" json:"-"`
|
||||
Repos []*Repository `xorm:"-" json:"-"`
|
||||
OwnedOrgs []*User `xorm:"-" gorm:"-" json:"-"`
|
||||
Orgs []*User `xorm:"-" gorm:"-" json:"-"`
|
||||
Repos []*Repository `xorm:"-" gorm:"-" json:"-"`
|
||||
Location string
|
||||
Website string
|
||||
Rands string `xorm:"VARCHAR(10)"`
|
||||
Salt string `xorm:"VARCHAR(10)"`
|
||||
Rands string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"`
|
||||
Salt string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"`
|
||||
|
||||
Created time.Time `xorm:"-" json:"-"`
|
||||
Created time.Time `xorm:"-" gorm:"-" json:"-"`
|
||||
CreatedUnix int64
|
||||
Updated time.Time `xorm:"-" json:"-"`
|
||||
Updated time.Time `xorm:"-" gorm:"-" json:"-"`
|
||||
UpdatedUnix int64
|
||||
|
||||
// Remember visibility choice for convenience, true for private
|
||||
LastRepoVisibility bool
|
||||
// Maximum repository creation limit, -1 means use gloabl default
|
||||
MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1"`
|
||||
// Maximum repository creation limit, -1 means use global default
|
||||
MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1" gorm:"NOT NULL;DEFAULT:-1"`
|
||||
|
||||
// Permissions
|
||||
IsActive bool // Activate primary email
|
||||
|
@ -84,13 +84,13 @@ type User struct {
|
|||
ProhibitLogin bool
|
||||
|
||||
// Avatar
|
||||
Avatar string `xorm:"VARCHAR(2048) NOT NULL"`
|
||||
AvatarEmail string `xorm:"NOT NULL"`
|
||||
Avatar string `xorm:"VARCHAR(2048) NOT NULL" gorm:"TYPE:VARCHAR(2048);NOT NULL"`
|
||||
AvatarEmail string `xorm:"NOT NULL" gorm:"NOT NULL"`
|
||||
UseCustomAvatar bool
|
||||
|
||||
// Counters
|
||||
NumFollowers int
|
||||
NumFollowing int `xorm:"NOT NULL DEFAULT 0"`
|
||||
NumFollowing int `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"`
|
||||
NumStars int
|
||||
NumRepos int
|
||||
|
||||
|
@ -98,8 +98,8 @@ type User struct {
|
|||
Description string
|
||||
NumTeams int
|
||||
NumMembers int
|
||||
Teams []*Team `xorm:"-" json:"-"`
|
||||
Members []*User `xorm:"-" json:"-"`
|
||||
Teams []*Team `xorm:"-" gorm:"-" json:"-"`
|
||||
Members []*User `xorm:"-" gorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
func (u *User) BeforeInsert() {
|
||||
|
|
|
@ -29,6 +29,8 @@ func (w *Writer) Print(v ...interface{}) {
|
|||
fmt.Fprintf(w.Writer, "[sql] [%s] [%s] %s %v (%d rows affected)", v[1:]...)
|
||||
case "log":
|
||||
fmt.Fprintf(w.Writer, "[log] [%s] %s", v[1:]...)
|
||||
case "error":
|
||||
fmt.Fprintf(w.Writer, "[err] [%s] %s", v[1:]...)
|
||||
default:
|
||||
fmt.Fprint(w.Writer, v...)
|
||||
}
|
||||
|
|
|
@ -41,6 +41,11 @@ func TestWriter_Print(t *testing.T) {
|
|||
vs: []interface{}{"log", "writer.go:65", "something"},
|
||||
expOutput: "[log] [writer.go:65] something",
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
vs: []interface{}{"error", "writer.go:65", "something bad"},
|
||||
expOutput: "[err] [writer.go:65] something bad",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
|
|
|
@ -17,6 +17,16 @@ func IsFile(path string) bool {
|
|||
return !f.IsDir()
|
||||
}
|
||||
|
||||
// IsDir returns true if given path is a directory, and returns false when it's
|
||||
// a file or does not exist.
|
||||
func IsDir(dir string) bool {
|
||||
f, e := os.Stat(dir)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
return f.IsDir()
|
||||
}
|
||||
|
||||
// IsExist returns true if a file or directory exists.
|
||||
func IsExist(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
|
|
|
@ -33,6 +33,29 @@ func TestIsFile(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestIsDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
expVal bool
|
||||
}{
|
||||
{
|
||||
path: "osutil.go",
|
||||
expVal: false,
|
||||
}, {
|
||||
path: "../osutil",
|
||||
expVal: true,
|
||||
}, {
|
||||
path: "not_found",
|
||||
expVal: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
assert.Equal(t, test.expVal, IsDir(test.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsExist(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/unknwon/com"
|
||||
log "unknwon.dev/clog/v2"
|
||||
"xorm.io/core"
|
||||
|
||||
"gogs.io/gogs/internal/auth/ldap"
|
||||
"gogs.io/gogs/internal/conf"
|
||||
|
@ -32,13 +31,13 @@ func Authentications(c *context.Context) {
|
|||
c.PageIs("AdminAuthentications")
|
||||
|
||||
var err error
|
||||
c.Data["Sources"], err = db.ListLoginSources()
|
||||
c.Data["Sources"], err = db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
}
|
||||
|
||||
c.Data["Total"] = db.CountLoginSources()
|
||||
c.Data["Total"] = db.LoginSources.Count()
|
||||
c.Success(AUTHS)
|
||||
}
|
||||
|
||||
|
@ -56,9 +55,9 @@ var (
|
|||
{db.LoginNames[db.LoginGitHub], db.LoginGitHub},
|
||||
}
|
||||
securityProtocols = []dropdownItem{
|
||||
{db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_UNENCRYPTED], ldap.SECURITY_PROTOCOL_UNENCRYPTED},
|
||||
{db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_LDAPS], ldap.SECURITY_PROTOCOL_LDAPS},
|
||||
{db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_START_TLS], ldap.SECURITY_PROTOCOL_START_TLS},
|
||||
{db.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted], ldap.SecurityProtocolUnencrypted},
|
||||
{db.SecurityProtocolNames[ldap.SecurityProtocolLDAPS], ldap.SecurityProtocolLDAPS},
|
||||
{db.SecurityProtocolNames[ldap.SecurityProtocolStartTLS], ldap.SecurityProtocolStartTLS},
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -69,7 +68,7 @@ func NewAuthSource(c *context.Context) {
|
|||
|
||||
c.Data["type"] = db.LoginLDAP
|
||||
c.Data["CurrentTypeName"] = db.LoginNames[db.LoginLDAP]
|
||||
c.Data["CurrentSecurityProtocol"] = db.SecurityProtocolNames[ldap.SECURITY_PROTOCOL_UNENCRYPTED]
|
||||
c.Data["CurrentSecurityProtocol"] = db.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted]
|
||||
c.Data["smtp_auth"] = "PLAIN"
|
||||
c.Data["is_active"] = true
|
||||
c.Data["is_default"] = true
|
||||
|
@ -81,7 +80,7 @@ func NewAuthSource(c *context.Context) {
|
|||
|
||||
func parseLDAPConfig(f form.Authentication) *db.LDAPConfig {
|
||||
return &db.LDAPConfig{
|
||||
Source: &ldap.Source{
|
||||
Source: ldap.Source{
|
||||
Host: f.Host,
|
||||
Port: f.Port,
|
||||
SecurityProtocol: ldap.SecurityProtocol(f.SecurityProtocol),
|
||||
|
@ -129,11 +128,11 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
c.Data["SMTPAuths"] = db.SMTPAuths
|
||||
|
||||
hasTLS := false
|
||||
var config core.Conversion
|
||||
var config interface{}
|
||||
switch db.LoginType(f.Type) {
|
||||
case db.LoginLDAP, db.LoginDLDAP:
|
||||
config = parseLDAPConfig(f)
|
||||
hasTLS = ldap.SecurityProtocol(f.SecurityProtocol) > ldap.SECURITY_PROTOCOL_UNENCRYPTED
|
||||
hasTLS = ldap.SecurityProtocol(f.SecurityProtocol) > ldap.SecurityProtocolUnencrypted
|
||||
case db.LoginSMTP:
|
||||
config = parseSMTPConfig(f)
|
||||
hasTLS = true
|
||||
|
@ -156,22 +155,31 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := db.CreateLoginSource(&db.LoginSource{
|
||||
source, err := db.LoginSources.Create(db.CreateLoginSourceOpts{
|
||||
Type: db.LoginType(f.Type),
|
||||
Name: f.Name,
|
||||
IsActived: f.IsActive,
|
||||
IsDefault: f.IsDefault,
|
||||
Cfg: config,
|
||||
}); err != nil {
|
||||
Activated: f.IsActive,
|
||||
Default: f.IsDefault,
|
||||
Config: config,
|
||||
})
|
||||
if err != nil {
|
||||
if db.IsErrLoginSourceAlreadyExist(err) {
|
||||
c.FormErr("Name")
|
||||
c.RenderWithErr(c.Tr("admin.auths.login_source_exist", err.(db.ErrLoginSourceAlreadyExist).Name), AUTH_NEW, f)
|
||||
c.RenderWithErr(c.Tr("admin.auths.login_source_exist", f.Name), AUTH_NEW, f)
|
||||
} else {
|
||||
c.Error(err, "create login source")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if source.IsDefault {
|
||||
err = db.LoginSources.ResetNonDefault(source)
|
||||
if err != nil {
|
||||
c.Error(err, "reset non-default login sources")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Trace("Authentication created by admin(%s): %s", c.User.Name, f.Name)
|
||||
|
||||
c.Flash.Success(c.Tr("admin.auths.new_success", f.Name))
|
||||
|
@ -217,7 +225,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
return
|
||||
}
|
||||
|
||||
var config core.Conversion
|
||||
var config interface{}
|
||||
switch db.LoginType(f.Type) {
|
||||
case db.LoginLDAP, db.LoginDLDAP:
|
||||
config = parseLDAPConfig(f)
|
||||
|
@ -239,12 +247,20 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
source.Name = f.Name
|
||||
source.IsActived = f.IsActive
|
||||
source.IsDefault = f.IsDefault
|
||||
source.Cfg = config
|
||||
if err := db.UpdateLoginSource(source); err != nil {
|
||||
source.Config = config
|
||||
if err := db.LoginSources.Save(source); err != nil {
|
||||
c.Error(err, "update login source")
|
||||
return
|
||||
}
|
||||
|
||||
if source.IsDefault {
|
||||
err = db.LoginSources.ResetNonDefault(source)
|
||||
if err != nil {
|
||||
c.Error(err, "reset non-default login sources")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Trace("Authentication changed by admin '%s': %d", c.User.Name, source.ID)
|
||||
|
||||
c.Flash.Success(c.Tr("admin.auths.update_success"))
|
||||
|
@ -252,13 +268,8 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
}
|
||||
|
||||
func DeleteAuthSource(c *context.Context) {
|
||||
source, err := db.LoginSources.GetByID(c.ParamsInt64(":authid"))
|
||||
if err != nil {
|
||||
c.Error(err, "get login source by ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err = db.DeleteSource(source); err != nil {
|
||||
id := c.ParamsInt64(":authid")
|
||||
if err := db.LoginSources.DeleteByID(id); err != nil {
|
||||
if db.IsErrLoginSourceInUse(err) {
|
||||
c.Flash.Error(c.Tr("admin.auths.still_in_used"))
|
||||
} else {
|
||||
|
@ -269,7 +280,7 @@ func DeleteAuthSource(c *context.Context) {
|
|||
})
|
||||
return
|
||||
}
|
||||
log.Trace("Authentication deleted by admin(%s): %d", c.User.Name, source.ID)
|
||||
log.Trace("Authentication deleted by admin(%s): %d", c.User.Name, id)
|
||||
|
||||
c.Flash.Success(c.Tr("admin.auths.deletion_success"))
|
||||
c.JSONSuccess(map[string]interface{}{
|
||||
|
|
|
@ -46,7 +46,7 @@ func NewUser(c *context.Context) {
|
|||
|
||||
c.Data["login_type"] = "0-0"
|
||||
|
||||
sources, err := db.ListLoginSources()
|
||||
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
|
@ -62,7 +62,7 @@ func NewUserPost(c *context.Context, f form.AdminCrateUser) {
|
|||
c.Data["PageIsAdmin"] = true
|
||||
c.Data["PageIsAdminUsers"] = true
|
||||
|
||||
sources, err := db.ListLoginSources()
|
||||
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
|
@ -141,7 +141,7 @@ func prepareUserInfo(c *context.Context) *db.User {
|
|||
c.Data["LoginSource"] = &db.LoginSource{}
|
||||
}
|
||||
|
||||
sources, err := db.ListLoginSources()
|
||||
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return nil
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"gogs.io/gogs/internal/conf"
|
||||
"gogs.io/gogs/internal/context"
|
||||
"gogs.io/gogs/internal/db"
|
||||
"gogs.io/gogs/internal/db/errors"
|
||||
"gogs.io/gogs/internal/email"
|
||||
"gogs.io/gogs/internal/route/api/v1/user"
|
||||
)
|
||||
|
@ -25,7 +24,7 @@ func parseLoginSource(c *context.APIContext, u *db.User, sourceID int64, loginNa
|
|||
|
||||
source, err := db.LoginSources.GetByID(sourceID)
|
||||
if err != nil {
|
||||
if errors.IsLoginSourceNotExist(err) {
|
||||
if db.IsErrLoginSourceNotExist(err) {
|
||||
c.ErrorStatus(http.StatusUnprocessableEntity, err)
|
||||
} else {
|
||||
c.Error(err, "get login source by ID")
|
||||
|
|
|
@ -76,7 +76,6 @@ func GlobalInit(customConf string) error {
|
|||
}
|
||||
db.HasEngine = true
|
||||
|
||||
db.LoadAuthSources()
|
||||
db.LoadRepoConfig()
|
||||
db.NewRepoContext()
|
||||
|
||||
|
|
|
@ -101,7 +101,7 @@ func Login(c *context.Context) {
|
|||
}
|
||||
|
||||
// Display normal login page
|
||||
loginSources, err := db.ActivatedLoginSources()
|
||||
loginSources, err := db.LoginSources.List(db.ListLoginSourceOpts{OnlyActivated: true})
|
||||
if err != nil {
|
||||
c.Error(err, "list activated login sources")
|
||||
return
|
||||
|
@ -148,7 +148,7 @@ func afterLogin(c *context.Context, u *db.User, remember bool) {
|
|||
func LoginPost(c *context.Context, f form.SignIn) {
|
||||
c.Title("sign_in")
|
||||
|
||||
loginSources, err := db.ActivatedLoginSources()
|
||||
loginSources, err := db.LoginSources.List(db.ListLoginSourceOpts{OnlyActivated: true})
|
||||
if err != nil {
|
||||
c.Error(err, "list activated login sources")
|
||||
return
|
||||
|
|
|
@ -176,7 +176,7 @@
|
|||
<input id="github_api_endpoint" name="github_api_endpoint" value="{{$cfg.APIEndpoint}}" placeholder="e.g. https://api.github.com/" required>
|
||||
</div>
|
||||
{{end}}
|
||||
|
||||
|
||||
<div class="inline field {{if not .Source.IsSMTP}}hide{{end}}">
|
||||
<div class="ui checkbox">
|
||||
<label><strong>{{.i18n.Tr "admin.auths.enable_tls"}}</strong></label>
|
||||
|
@ -203,7 +203,7 @@
|
|||
</div>
|
||||
<div class="field">
|
||||
<button class="ui green button">{{.i18n.Tr "admin.auths.update"}}</button>
|
||||
{{if not .Source.LocalFile}}
|
||||
{{if not .Source.File}}
|
||||
<div class="ui red button delete-button" data-url="{{$.Link}}/delete" data-id="{{.Source.ID}}">{{.i18n.Tr "admin.auths.delete"}}</div>
|
||||
{{end}}
|
||||
</div>
|
||||
|
|
Loading…
Reference in New Issue