gogs/internal/db/login_sources.go

317 lines
7.9 KiB
Go

// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"strconv"
"time"
"github.com/jinzhu/gorm"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"gogs.io/gogs/internal/auth/ldap"
"gogs.io/gogs/internal/errutil"
)
// LoginSourcesStore is the persistent interface for login sources.
//
// NOTE: All methods are sorted in alphabetical order.
type LoginSourcesStore interface {
// Create creates a new login source and persist to database.
// It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
Create(opts CreateLoginSourceOpts) (*LoginSource, error)
// Count returns the total number of login sources.
Count() int64
// DeleteByID deletes a login source by given ID.
// It returns ErrLoginSourceInUse if at least one user is associated with the login source.
DeleteByID(id int64) error
// GetByID returns the login source with given ID.
// It returns ErrLoginSourceNotExist when not found.
GetByID(id int64) (*LoginSource, error)
// List returns a list of login sources filtered by options.
List(opts ListLoginSourceOpts) ([]*LoginSource, error)
// ResetNonDefault clears default flag for all the other login sources.
ResetNonDefault(source *LoginSource) error
// Save persists all values of given login source to database or local file.
// The Updated field is set to current time automatically.
Save(t *LoginSource) error
}
var LoginSources LoginSourcesStore
// LoginSource represents an external way for authorizing users.
type LoginSource struct {
ID int64
Type LoginType
Name string `xorm:"UNIQUE" gorm:"UNIQUE"`
IsActived bool `xorm:"NOT NULL DEFAULT false" gorm:"NOT NULL"`
IsDefault bool `xorm:"DEFAULT false"`
Config interface{} `xorm:"-" gorm:"-"`
RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg"`
Created time.Time `xorm:"-" gorm:"-" json:"-"`
CreatedUnix int64
Updated time.Time `xorm:"-" gorm:"-" json:"-"`
UpdatedUnix int64
File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
}
// NOTE: This is a GORM save hook.
func (s *LoginSource) BeforeSave() (err error) {
if s.Config == nil {
return nil
}
s.RawConfig, err = jsoniter.MarshalToString(s.Config)
return err
}
// NOTE: This is a GORM create hook.
func (s *LoginSource) BeforeCreate() {
if s.CreatedUnix > 0 {
return
}
s.CreatedUnix = gorm.NowFunc().Unix()
s.UpdatedUnix = s.CreatedUnix
}
// NOTE: This is a GORM update hook.
func (s *LoginSource) BeforeUpdate() {
s.UpdatedUnix = gorm.NowFunc().Unix()
}
// NOTE: This is a GORM query hook.
func (s *LoginSource) AfterFind() error {
s.Created = time.Unix(s.CreatedUnix, 0).Local()
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
switch s.Type {
case LoginLDAP, LoginDLDAP:
s.Config = new(LDAPConfig)
case LoginSMTP:
s.Config = new(SMTPConfig)
case LoginPAM:
s.Config = new(PAMConfig)
case LoginGitHub:
s.Config = new(GitHubConfig)
default:
return fmt.Errorf("unrecognized login source type: %v", s.Type)
}
return jsoniter.UnmarshalFromString(s.RawConfig, s.Config)
}
func (s *LoginSource) TypeName() string {
return LoginNames[s.Type]
}
func (s *LoginSource) IsLDAP() bool {
return s.Type == LoginLDAP
}
func (s *LoginSource) IsDLDAP() bool {
return s.Type == LoginDLDAP
}
func (s *LoginSource) IsSMTP() bool {
return s.Type == LoginSMTP
}
func (s *LoginSource) IsPAM() bool {
return s.Type == LoginPAM
}
func (s *LoginSource) IsGitHub() bool {
return s.Type == LoginGitHub
}
func (s *LoginSource) HasTLS() bool {
return ((s.IsLDAP() || s.IsDLDAP()) &&
s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) ||
s.IsSMTP()
}
func (s *LoginSource) UseTLS() bool {
switch s.Type {
case LoginLDAP, LoginDLDAP:
return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted
case LoginSMTP:
return s.SMTP().TLS
}
return false
}
func (s *LoginSource) SkipVerify() bool {
switch s.Type {
case LoginLDAP, LoginDLDAP:
return s.LDAP().SkipVerify
case LoginSMTP:
return s.SMTP().SkipVerify
}
return false
}
func (s *LoginSource) LDAP() *LDAPConfig {
return s.Config.(*LDAPConfig)
}
func (s *LoginSource) SMTP() *SMTPConfig {
return s.Config.(*SMTPConfig)
}
func (s *LoginSource) PAM() *PAMConfig {
return s.Config.(*PAMConfig)
}
func (s *LoginSource) GitHub() *GitHubConfig {
return s.Config.(*GitHubConfig)
}
var _ LoginSourcesStore = (*loginSources)(nil)
type loginSources struct {
*gorm.DB
files loginSourceFilesStore
}
type CreateLoginSourceOpts struct {
Type LoginType
Name string
Activated bool
Default bool
Config interface{}
}
type ErrLoginSourceAlreadyExist struct {
args errutil.Args
}
func IsErrLoginSourceAlreadyExist(err error) bool {
_, ok := err.(ErrLoginSourceAlreadyExist)
return ok
}
func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists: %v", err.args)
}
func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
} else if !gorm.IsRecordNotFoundError(err) {
return nil, err
}
source := &LoginSource{
Type: opts.Type,
Name: opts.Name,
IsActived: opts.Activated,
IsDefault: opts.Default,
Config: opts.Config,
}
return source, db.DB.Create(source).Error
}
func (db *loginSources) Count() int64 {
var count int64
db.Model(new(LoginSource)).Count(&count)
return count + int64(db.files.Len())
}
type ErrLoginSourceInUse struct {
args errutil.Args
}
func IsErrLoginSourceInUse(err error) bool {
_, ok := err.(ErrLoginSourceInUse)
return ok
}
func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users: %v", err.args)
}
func (db *loginSources) DeleteByID(id int64) error {
var count int64
err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
}
return db.Where("id = ?", id).Delete(new(LoginSource)).Error
}
func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
source := new(LoginSource)
err := db.Where("id = ?", id).First(source).Error
if err != nil {
if gorm.IsRecordNotFoundError(err) {
return db.files.GetByID(id)
}
return nil, err
}
return source, nil
}
type ListLoginSourceOpts struct {
// Whether to only include activated login sources.
OnlyActivated bool
}
func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
var sources []*LoginSource
query := db.Order("id ASC")
if opts.OnlyActivated {
query = query.Where("is_actived = ?", true)
}
err := query.Find(&sources).Error
if err != nil {
return nil, err
}
return append(sources, db.files.List(opts)...), nil
}
func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
if err != nil {
return err
}
for _, source := range db.files.List(ListLoginSourceOpts{}) {
if source.File != nil && source.ID != dflt.ID {
source.File.SetGeneral("is_default", "false")
if err = source.File.Save(); err != nil {
return errors.Wrap(err, "save file")
}
}
}
db.files.Update(dflt)
return nil
}
func (db *loginSources) Save(source *LoginSource) error {
if source.File == nil {
return db.DB.Save(source).Error
}
source.File.SetGeneral("name", source.Name)
source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
if err := source.File.SetConfig(source.Config); err != nil {
return errors.Wrap(err, "set config")
} else if err = source.File.Save(); err != nil {
return errors.Wrap(err, "save file")
}
return nil
}