all: unwrap `database.LoginSourcesStore` interface (#7694)

pull/7695/head
Joe Chen 2024-03-17 20:14:54 -04:00 committed by GitHub
parent 3a5132b6f7
commit e634aa6277
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 798 additions and 1672 deletions

View File

@ -23,6 +23,8 @@ const (
PAM // 4
DLDAP // 5
GitHub // 6
Mock Type = 999
)
// Name returns the human-readable name for given authentication type.
@ -45,8 +47,7 @@ type ErrBadCredentials struct {
// IsErrBadCredentials returns true if the underlying error has the type
// ErrBadCredentials.
func IsErrBadCredentials(err error) bool {
_, ok := errors.Cause(err).(ErrBadCredentials)
return ok
return errors.As(err, &ErrBadCredentials{})
}
func (err ErrBadCredentials) Error() string {

View File

@ -117,13 +117,12 @@ func NewConnection(w logger.Writer) (*gorm.DB, error) {
log.Trace("Auto migrated %q", name)
}
sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
loadedLoginSourceFilesStore, err = loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
if err != nil {
return nil, errors.Wrap(err, "load login source files")
}
// Initialize stores, sorted in alphabetical order.
LoginSources = &loginSourcesStore{DB: db, files: sourceFiles}
Notices = NewNoticesStore(db)
Orgs = NewOrgsStore(db)
Perms = NewPermsStore(db)
@ -166,3 +165,11 @@ func (db *DB) Actions() *ActionsStore {
func (db *DB) LFS() *LFSStore {
return newLFSStore(db.db)
}
// NOTE: It is not guarded by a mutex because it only gets written during the
// service start.
var loadedLoginSourceFilesStore loginSourceFilesStore
func (db *DB) LoginSources() *LoginSourcesStore {
return newLoginSourcesStore(db.db, loadedLoginSourceFilesStore)
}

View File

@ -52,8 +52,7 @@ type ErrLoginSourceNotExist struct {
}
func IsErrLoginSourceNotExist(err error) bool {
_, ok := err.(ErrLoginSourceNotExist)
return ok
return errors.As(err, &ErrLoginSourceNotExist{})
}
func (err ErrLoginSourceNotExist) Error() string {

View File

@ -22,30 +22,6 @@ import (
"gogs.io/gogs/internal/errutil"
)
// LoginSourcesStore is the persistent interface for login sources.
type LoginSourcesStore interface {
// Create creates a new login source and persist to database. It returns
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error)
// Count returns the total number of login sources.
Count(ctx context.Context) int64
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
// if at least one user is associated with the login source.
DeleteByID(ctx context.Context, id int64) error
// GetByID returns the login source with given ID. It returns
// ErrLoginSourceNotExist when not found.
GetByID(ctx context.Context, id int64) (*LoginSource, error)
// List returns a list of login sources filtered by options.
List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error)
// ResetNonDefault clears default flag for all the other login sources.
ResetNonDefault(ctx context.Context, source *LoginSource) error
// Save persists all values of given login source to database or local file. The
// Updated field is set to current time automatically.
Save(ctx context.Context, t *LoginSource) error
}
var LoginSources LoginSourcesStore
// LoginSource represents an external way for authorizing users.
type LoginSource struct {
ID int64 `gorm:"primaryKey"`
@ -88,6 +64,10 @@ func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
return nil
}
type mockProviderConfig struct {
ExternalAccount *auth.ExternalAccount
}
// AfterFind implements the GORM query hook.
func (s *LoginSource) AfterFind(_ *gorm.DB) error {
s.Created = time.Unix(s.CreatedUnix, 0).Local()
@ -134,6 +114,16 @@ func (s *LoginSource) AfterFind(_ *gorm.DB) error {
}
s.Provider = github.NewProvider(&cfg)
case auth.Mock:
var cfg mockProviderConfig
err := jsoniter.UnmarshalFromString(s.Config, &cfg)
if err != nil {
return err
}
mockProvider := NewMockProvider()
mockProvider.AuthenticateFunc.SetDefaultReturn(cfg.ExternalAccount, nil)
s.Provider = mockProvider
default:
return fmt.Errorf("unrecognized login source type: %v", s.Type)
}
@ -180,13 +170,19 @@ func (s *LoginSource) GitHub() *github.Config {
return s.Provider.Config().(*github.Config)
}
var _ LoginSourcesStore = (*loginSourcesStore)(nil)
type loginSourcesStore struct {
*gorm.DB
// LoginSourcesStore is the storage layer for login sources.
type LoginSourcesStore struct {
db *gorm.DB
files loginSourceFilesStore
}
func newLoginSourcesStore(db *gorm.DB, files loginSourceFilesStore) *LoginSourcesStore {
return &LoginSourcesStore{
db: db,
files: files,
}
}
type CreateLoginSourceOptions struct {
Type auth.Type
Name string
@ -200,19 +196,20 @@ type ErrLoginSourceAlreadyExist struct {
}
func IsErrLoginSourceAlreadyExist(err error) bool {
_, ok := err.(ErrLoginSourceAlreadyExist)
return ok
return errors.As(err, &ErrLoginSourceAlreadyExist{})
}
func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists: %v", err.args)
}
func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
// Create creates a new login source and persists it to the database. It returns
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
func (s *LoginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
err := s.db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
} else if err != gorm.ErrRecordNotFound {
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
@ -226,12 +223,13 @@ func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOp
if err != nil {
return nil, err
}
return source, s.WithContext(ctx).Create(source).Error
return source, s.db.WithContext(ctx).Create(source).Error
}
func (s *loginSourcesStore) Count(ctx context.Context) int64 {
// Count returns the total number of login sources.
func (s *LoginSourcesStore) Count(ctx context.Context) int64 {
var count int64
s.WithContext(ctx).Model(new(LoginSource)).Count(&count)
s.db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
return count + int64(s.files.Len())
}
@ -240,31 +238,34 @@ type ErrLoginSourceInUse struct {
}
func IsErrLoginSourceInUse(err error) bool {
_, ok := err.(ErrLoginSourceInUse)
return ok
return errors.As(err, &ErrLoginSourceInUse{})
}
func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users: %v", err.args)
}
func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
// if at least one user is associated with the login source.
func (s *LoginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
var count int64
err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
err := s.db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
}
return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
return s.db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
}
func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
// GetByID returns the login source with given ID. It returns
// ErrLoginSourceNotExist when not found.
func (s *LoginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
source := new(LoginSource)
err := s.WithContext(ctx).Where("id = ?", id).First(source).Error
err := s.db.WithContext(ctx).Where("id = ?", id).First(source).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return s.files.GetByID(id)
}
return nil, err
@ -277,9 +278,10 @@ type ListLoginSourceOptions struct {
OnlyActivated bool
}
func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
// List returns a list of login sources filtered by options.
func (s *LoginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
var sources []*LoginSource
query := s.WithContext(ctx).Order("id ASC")
query := s.db.WithContext(ctx).Order("id ASC")
if opts.OnlyActivated {
query = query.Where("is_actived = ?", true)
}
@ -291,8 +293,9 @@ func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOption
return append(sources, s.files.List(opts)...), nil
}
func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := s.WithContext(ctx).
// ResetNonDefault clears default flag for all the other login sources.
func (s *LoginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := s.db.WithContext(ctx).
Model(new(LoginSource)).
Where("id != ?", dflt.ID).
Updates(map[string]any{"is_default": false}).
@ -314,9 +317,11 @@ func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSour
return nil
}
func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
// Save persists all values of given login source to database or local file. The
// Updated field is set to current time automatically.
func (s *LoginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
if source.File == nil {
return s.WithContext(ctx).Save(source).Error
return s.db.WithContext(ctx).Save(source).Error
}
source.File.SetGeneral("name", source.Name)

View File

@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &loginSourcesStore{
DB: newTestDB(t, "loginSourcesStore"),
s := &LoginSourcesStore{
db: newTestDB(t, "LoginSourcesStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
test func(t *testing.T, ctx context.Context, s *LoginSourcesStore)
}{
{"Create", loginSourcesCreate},
{"Count", loginSourcesCount},
@ -181,10 +181,10 @@ func TestLoginSources(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := clearTables(t, db.DB)
err := clearTables(t, s.db)
require.NoError(t, err)
})
tc.test(t, ctx, db)
tc.test(t, ctx, s)
})
if t.Failed() {
break
@ -192,9 +192,9 @@ func TestLoginSources(t *testing.T) {
}
}
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
func loginSourcesCreate(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
// Create first login source with name "GitHub"
source, err := db.Create(ctx,
source, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub",
@ -208,20 +208,28 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore
require.NoError(t, err)
// Get it back and check the Created field
source, err = db.GetByID(ctx, source.ID)
source, err = s.GetByID(ctx, source.ID)
require.NoError(t, err)
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
// Try create second login source with same name should fail
_, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
// Try to create second login source with same name should fail.
_, err = s.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
assert.Equal(t, wantErr, err)
}
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
func setMockLoginSourceFilesStore(t *testing.T, s *LoginSourcesStore, mock loginSourceFilesStore) {
before := s.files
s.files = mock
t.Cleanup(func() {
s.files = before
})
}
func loginSourcesCount(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
// Create two login sources, one in database and one as source file.
_, err := db.Create(ctx,
_, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub",
@ -236,14 +244,14 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore)
mock := NewMockLoginSourceFilesStore()
mock.LenFunc.SetDefaultReturn(2)
setMockLoginSourceFilesStore(t, db, mock)
setMockLoginSourceFilesStore(t, s, mock)
assert.Equal(t, int64(3), db.Count(ctx))
assert.Equal(t, int64(3), s.Count(ctx))
}
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
t.Run("delete but in used", func(t *testing.T) {
source, err := db.Create(ctx,
source, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub",
@ -257,7 +265,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
require.NoError(t, err)
// Create a user that uses this login source
_, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
_, err = NewUsersStore(s.db).Create(ctx, "alice", "",
CreateUserOptions{
LoginSource: source.ID,
},
@ -265,7 +273,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
require.NoError(t, err)
// Delete the login source will result in error
err = db.DeleteByID(ctx, source.ID)
err = s.DeleteByID(ctx, source.ID)
wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
assert.Equal(t, wantErr, err)
})
@ -274,10 +282,10 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
})
setMockLoginSourceFilesStore(t, db, mock)
setMockLoginSourceFilesStore(t, s, mock)
// Create a login source with name "GitHub2"
source, err := db.Create(ctx,
source, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub2",
@ -291,24 +299,24 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
require.NoError(t, err)
// Delete a non-existent ID is noop
err = db.DeleteByID(ctx, 9999)
err = s.DeleteByID(ctx, 9999)
require.NoError(t, err)
// We should be able to get it back
_, err = db.GetByID(ctx, source.ID)
_, err = s.GetByID(ctx, source.ID)
require.NoError(t, err)
// Now delete this login source with ID
err = db.DeleteByID(ctx, source.ID)
err = s.DeleteByID(ctx, source.ID)
require.NoError(t, err)
// We should get token not found error
_, err = db.GetByID(ctx, source.ID)
_, err = s.GetByID(ctx, source.ID)
wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
assert.Equal(t, wantErr, err)
}
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
func loginSourcesGetByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
mock := NewMockLoginSourceFilesStore()
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
if id != 101 {
@ -316,14 +324,14 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
}
return &LoginSource{ID: id}, nil
})
setMockLoginSourceFilesStore(t, db, mock)
setMockLoginSourceFilesStore(t, s, mock)
expConfig := &github.Config{
APIEndpoint: "https://api.github.com",
}
// Create a login source with name "GitHub"
source, err := db.Create(ctx,
source, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub",
@ -335,16 +343,16 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
require.NoError(t, err)
// Get the one in the database and test the read/write hooks
source, err = db.GetByID(ctx, source.ID)
source, err = s.GetByID(ctx, source.ID)
require.NoError(t, err)
assert.Equal(t, expConfig, source.Provider.Config())
// Get the one in source file store
_, err = db.GetByID(ctx, 101)
_, err = s.GetByID(ctx, 101)
require.NoError(t, err)
}
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
func loginSourcesList(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
if opts.OnlyActivated {
@ -357,10 +365,10 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
{ID: 2},
}
})
setMockLoginSourceFilesStore(t, db, mock)
setMockLoginSourceFilesStore(t, s, mock)
// Create two login sources in database, one activated and the other one not
_, err := db.Create(ctx,
_, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.PAM,
Name: "PAM",
@ -370,7 +378,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
},
)
require.NoError(t, err)
_, err = db.Create(ctx,
_, err = s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub",
@ -383,17 +391,17 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
require.NoError(t, err)
// List all login sources
sources, err := db.List(ctx, ListLoginSourceOptions{})
sources, err := s.List(ctx, ListLoginSourceOptions{})
require.NoError(t, err)
assert.Equal(t, 4, len(sources), "number of sources")
// Only list activated login sources
sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
sources, err = s.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
require.NoError(t, err)
assert.Equal(t, 2, len(sources), "number of sources")
}
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
mockFile := NewMockLoginSourceFileStore()
@ -407,10 +415,10 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
},
}
})
setMockLoginSourceFilesStore(t, db, mock)
setMockLoginSourceFilesStore(t, s, mock)
// Create two login sources both have default on
source1, err := db.Create(ctx,
source1, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.PAM,
Name: "PAM",
@ -421,7 +429,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
},
)
require.NoError(t, err)
source2, err := db.Create(ctx,
source2, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub",
@ -435,23 +443,23 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
require.NoError(t, err)
// Set source 1 as default
err = db.ResetNonDefault(ctx, source1)
err = s.ResetNonDefault(ctx, source1)
require.NoError(t, err)
// Verify the default state
source1, err = db.GetByID(ctx, source1.ID)
source1, err = s.GetByID(ctx, source1.ID)
require.NoError(t, err)
assert.True(t, source1.IsDefault)
source2, err = db.GetByID(ctx, source2.ID)
source2, err = s.GetByID(ctx, source2.ID)
require.NoError(t, err)
assert.False(t, source2.IsDefault)
}
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
func loginSourcesSave(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
t.Run("save to database", func(t *testing.T) {
// Create a login source with name "GitHub"
source, err := db.Create(ctx,
source, err := s.Create(ctx,
CreateLoginSourceOptions{
Type: auth.GitHub,
Name: "GitHub",
@ -468,10 +476,10 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
source.Provider = github.NewProvider(&github.Config{
APIEndpoint: "https://api2.github.com",
})
err = db.Save(ctx, source)
err = s.Save(ctx, source)
require.NoError(t, err)
source, err = db.GetByID(ctx, source.ID)
source, err = s.GetByID(ctx, source.ID)
require.NoError(t, err)
assert.False(t, source.IsActived)
assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
@ -485,7 +493,7 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
}),
File: mockFile,
}
err := db.Save(ctx, source)
err := s.Save(ctx, source)
require.NoError(t, err)
mockrequire.Called(t, mockFile.SaveFunc)
})

View File

@ -8,22 +8,6 @@ import (
"testing"
)
func setMockLoginSourcesStore(t *testing.T, mock LoginSourcesStore) {
before := LoginSources
LoginSources = mock
t.Cleanup(func() {
LoginSources = before
})
}
func setMockLoginSourceFilesStore(t *testing.T, db *loginSourcesStore, mock loginSourceFilesStore) {
before := db.files
db.files = mock
t.Cleanup(func() {
db.files = before
})
}
func SetMockPermsStore(t *testing.T, mock PermsStore) {
before := Perms
Perms = mock

View File

@ -0,0 +1,620 @@
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
//
// This file was generated by running `go-mockgen` at the root of this repository.
// To add additional mocks to this or another package, add a new entry to the
// mockgen.yaml file in the root of this repository.
package database
import (
"sync"
auth "gogs.io/gogs/internal/auth"
)
// MockProvider is a mock implementation of the Provider interface (from the
// package gogs.io/gogs/internal/auth) used for unit testing.
type MockProvider struct {
// AuthenticateFunc is an instance of a mock function object controlling
// the behavior of the method Authenticate.
AuthenticateFunc *ProviderAuthenticateFunc
// ConfigFunc is an instance of a mock function object controlling the
// behavior of the method Config.
ConfigFunc *ProviderConfigFunc
// HasTLSFunc is an instance of a mock function object controlling the
// behavior of the method HasTLS.
HasTLSFunc *ProviderHasTLSFunc
// SkipTLSVerifyFunc is an instance of a mock function object
// controlling the behavior of the method SkipTLSVerify.
SkipTLSVerifyFunc *ProviderSkipTLSVerifyFunc
// UseTLSFunc is an instance of a mock function object controlling the
// behavior of the method UseTLS.
UseTLSFunc *ProviderUseTLSFunc
}
// NewMockProvider creates a new mock of the Provider interface. All methods
// return zero values for all results, unless overwritten.
func NewMockProvider() *MockProvider {
return &MockProvider{
AuthenticateFunc: &ProviderAuthenticateFunc{
defaultHook: func(string, string) (r0 *auth.ExternalAccount, r1 error) {
return
},
},
ConfigFunc: &ProviderConfigFunc{
defaultHook: func() (r0 interface{}) {
return
},
},
HasTLSFunc: &ProviderHasTLSFunc{
defaultHook: func() (r0 bool) {
return
},
},
SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
defaultHook: func() (r0 bool) {
return
},
},
UseTLSFunc: &ProviderUseTLSFunc{
defaultHook: func() (r0 bool) {
return
},
},
}
}
// NewStrictMockProvider creates a new mock of the Provider interface. All
// methods panic on invocation, unless overwritten.
func NewStrictMockProvider() *MockProvider {
return &MockProvider{
AuthenticateFunc: &ProviderAuthenticateFunc{
defaultHook: func(string, string) (*auth.ExternalAccount, error) {
panic("unexpected invocation of MockProvider.Authenticate")
},
},
ConfigFunc: &ProviderConfigFunc{
defaultHook: func() interface{} {
panic("unexpected invocation of MockProvider.Config")
},
},
HasTLSFunc: &ProviderHasTLSFunc{
defaultHook: func() bool {
panic("unexpected invocation of MockProvider.HasTLS")
},
},
SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
defaultHook: func() bool {
panic("unexpected invocation of MockProvider.SkipTLSVerify")
},
},
UseTLSFunc: &ProviderUseTLSFunc{
defaultHook: func() bool {
panic("unexpected invocation of MockProvider.UseTLS")
},
},
}
}
// NewMockProviderFrom creates a new mock of the MockProvider interface. All
// methods delegate to the given implementation, unless overwritten.
func NewMockProviderFrom(i auth.Provider) *MockProvider {
return &MockProvider{
AuthenticateFunc: &ProviderAuthenticateFunc{
defaultHook: i.Authenticate,
},
ConfigFunc: &ProviderConfigFunc{
defaultHook: i.Config,
},
HasTLSFunc: &ProviderHasTLSFunc{
defaultHook: i.HasTLS,
},
SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
defaultHook: i.SkipTLSVerify,
},
UseTLSFunc: &ProviderUseTLSFunc{
defaultHook: i.UseTLS,
},
}
}
// ProviderAuthenticateFunc describes the behavior when the Authenticate
// method of the parent MockProvider instance is invoked.
type ProviderAuthenticateFunc struct {
defaultHook func(string, string) (*auth.ExternalAccount, error)
hooks []func(string, string) (*auth.ExternalAccount, error)
history []ProviderAuthenticateFuncCall
mutex sync.Mutex
}
// Authenticate delegates to the next hook function in the queue and stores
// the parameter and result values of this invocation.
func (m *MockProvider) Authenticate(v0 string, v1 string) (*auth.ExternalAccount, error) {
r0, r1 := m.AuthenticateFunc.nextHook()(v0, v1)
m.AuthenticateFunc.appendCall(ProviderAuthenticateFuncCall{v0, v1, r0, r1})
return r0, r1
}
// SetDefaultHook sets function that is called when the Authenticate method
// of the parent MockProvider instance is invoked and the hook queue is
// empty.
func (f *ProviderAuthenticateFunc) SetDefaultHook(hook func(string, string) (*auth.ExternalAccount, error)) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// Authenticate method of the parent MockProvider instance invokes the hook
// at the front of the queue and discards it. After the queue is empty, the
// default hook function is invoked for any future action.
func (f *ProviderAuthenticateFunc) PushHook(hook func(string, string) (*auth.ExternalAccount, error)) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ProviderAuthenticateFunc) SetDefaultReturn(r0 *auth.ExternalAccount, r1 error) {
f.SetDefaultHook(func(string, string) (*auth.ExternalAccount, error) {
return r0, r1
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ProviderAuthenticateFunc) PushReturn(r0 *auth.ExternalAccount, r1 error) {
f.PushHook(func(string, string) (*auth.ExternalAccount, error) {
return r0, r1
})
}
func (f *ProviderAuthenticateFunc) nextHook() func(string, string) (*auth.ExternalAccount, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ProviderAuthenticateFunc) appendCall(r0 ProviderAuthenticateFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of ProviderAuthenticateFuncCall objects
// describing the invocations of this function.
func (f *ProviderAuthenticateFunc) History() []ProviderAuthenticateFuncCall {
f.mutex.Lock()
history := make([]ProviderAuthenticateFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ProviderAuthenticateFuncCall is an object that describes an invocation of
// method Authenticate on an instance of MockProvider.
type ProviderAuthenticateFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 string
// Arg1 is the value of the 2nd argument passed to this method
// invocation.
Arg1 string
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 *auth.ExternalAccount
// Result1 is the value of the 2nd result returned from this method
// invocation.
Result1 error
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ProviderAuthenticateFuncCall) Args() []interface{} {
return []interface{}{c.Arg0, c.Arg1}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ProviderAuthenticateFuncCall) Results() []interface{} {
return []interface{}{c.Result0, c.Result1}
}
// ProviderConfigFunc describes the behavior when the Config method of the
// parent MockProvider instance is invoked.
type ProviderConfigFunc struct {
defaultHook func() interface{}
hooks []func() interface{}
history []ProviderConfigFuncCall
mutex sync.Mutex
}
// Config delegates to the next hook function in the queue and stores the
// parameter and result values of this invocation.
func (m *MockProvider) Config() interface{} {
r0 := m.ConfigFunc.nextHook()()
m.ConfigFunc.appendCall(ProviderConfigFuncCall{r0})
return r0
}
// SetDefaultHook sets function that is called when the Config method of the
// parent MockProvider instance is invoked and the hook queue is empty.
func (f *ProviderConfigFunc) SetDefaultHook(hook func() interface{}) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// Config method of the parent MockProvider instance invokes the hook at the
// front of the queue and discards it. After the queue is empty, the default
// hook function is invoked for any future action.
func (f *ProviderConfigFunc) PushHook(hook func() interface{}) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ProviderConfigFunc) SetDefaultReturn(r0 interface{}) {
f.SetDefaultHook(func() interface{} {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ProviderConfigFunc) PushReturn(r0 interface{}) {
f.PushHook(func() interface{} {
return r0
})
}
func (f *ProviderConfigFunc) nextHook() func() interface{} {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ProviderConfigFunc) appendCall(r0 ProviderConfigFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of ProviderConfigFuncCall objects describing
// the invocations of this function.
func (f *ProviderConfigFunc) History() []ProviderConfigFuncCall {
f.mutex.Lock()
history := make([]ProviderConfigFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ProviderConfigFuncCall is an object that describes an invocation of
// method Config on an instance of MockProvider.
type ProviderConfigFuncCall struct {
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 interface{}
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ProviderConfigFuncCall) Args() []interface{} {
return []interface{}{}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ProviderConfigFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
// ProviderHasTLSFunc describes the behavior when the HasTLS method of the
// parent MockProvider instance is invoked.
type ProviderHasTLSFunc struct {
defaultHook func() bool
hooks []func() bool
history []ProviderHasTLSFuncCall
mutex sync.Mutex
}
// HasTLS delegates to the next hook function in the queue and stores the
// parameter and result values of this invocation.
func (m *MockProvider) HasTLS() bool {
r0 := m.HasTLSFunc.nextHook()()
m.HasTLSFunc.appendCall(ProviderHasTLSFuncCall{r0})
return r0
}
// SetDefaultHook sets function that is called when the HasTLS method of the
// parent MockProvider instance is invoked and the hook queue is empty.
func (f *ProviderHasTLSFunc) SetDefaultHook(hook func() bool) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// HasTLS method of the parent MockProvider instance invokes the hook at the
// front of the queue and discards it. After the queue is empty, the default
// hook function is invoked for any future action.
func (f *ProviderHasTLSFunc) PushHook(hook func() bool) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ProviderHasTLSFunc) SetDefaultReturn(r0 bool) {
f.SetDefaultHook(func() bool {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ProviderHasTLSFunc) PushReturn(r0 bool) {
f.PushHook(func() bool {
return r0
})
}
func (f *ProviderHasTLSFunc) nextHook() func() bool {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ProviderHasTLSFunc) appendCall(r0 ProviderHasTLSFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of ProviderHasTLSFuncCall objects describing
// the invocations of this function.
func (f *ProviderHasTLSFunc) History() []ProviderHasTLSFuncCall {
f.mutex.Lock()
history := make([]ProviderHasTLSFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ProviderHasTLSFuncCall is an object that describes an invocation of
// method HasTLS on an instance of MockProvider.
type ProviderHasTLSFuncCall struct {
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 bool
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ProviderHasTLSFuncCall) Args() []interface{} {
return []interface{}{}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ProviderHasTLSFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
// ProviderSkipTLSVerifyFunc describes the behavior when the SkipTLSVerify
// method of the parent MockProvider instance is invoked.
type ProviderSkipTLSVerifyFunc struct {
defaultHook func() bool
hooks []func() bool
history []ProviderSkipTLSVerifyFuncCall
mutex sync.Mutex
}
// SkipTLSVerify delegates to the next hook function in the queue and stores
// the parameter and result values of this invocation.
func (m *MockProvider) SkipTLSVerify() bool {
r0 := m.SkipTLSVerifyFunc.nextHook()()
m.SkipTLSVerifyFunc.appendCall(ProviderSkipTLSVerifyFuncCall{r0})
return r0
}
// SetDefaultHook sets function that is called when the SkipTLSVerify method
// of the parent MockProvider instance is invoked and the hook queue is
// empty.
func (f *ProviderSkipTLSVerifyFunc) SetDefaultHook(hook func() bool) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// SkipTLSVerify method of the parent MockProvider instance invokes the hook
// at the front of the queue and discards it. After the queue is empty, the
// default hook function is invoked for any future action.
func (f *ProviderSkipTLSVerifyFunc) PushHook(hook func() bool) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ProviderSkipTLSVerifyFunc) SetDefaultReturn(r0 bool) {
f.SetDefaultHook(func() bool {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ProviderSkipTLSVerifyFunc) PushReturn(r0 bool) {
f.PushHook(func() bool {
return r0
})
}
func (f *ProviderSkipTLSVerifyFunc) nextHook() func() bool {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ProviderSkipTLSVerifyFunc) appendCall(r0 ProviderSkipTLSVerifyFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of ProviderSkipTLSVerifyFuncCall objects
// describing the invocations of this function.
func (f *ProviderSkipTLSVerifyFunc) History() []ProviderSkipTLSVerifyFuncCall {
f.mutex.Lock()
history := make([]ProviderSkipTLSVerifyFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ProviderSkipTLSVerifyFuncCall is an object that describes an invocation
// of method SkipTLSVerify on an instance of MockProvider.
type ProviderSkipTLSVerifyFuncCall struct {
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 bool
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ProviderSkipTLSVerifyFuncCall) Args() []interface{} {
return []interface{}{}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ProviderSkipTLSVerifyFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
// ProviderUseTLSFunc describes the behavior when the UseTLS method of the
// parent MockProvider instance is invoked.
type ProviderUseTLSFunc struct {
defaultHook func() bool
hooks []func() bool
history []ProviderUseTLSFuncCall
mutex sync.Mutex
}
// UseTLS delegates to the next hook function in the queue and stores the
// parameter and result values of this invocation.
func (m *MockProvider) UseTLS() bool {
r0 := m.UseTLSFunc.nextHook()()
m.UseTLSFunc.appendCall(ProviderUseTLSFuncCall{r0})
return r0
}
// SetDefaultHook sets function that is called when the UseTLS method of the
// parent MockProvider instance is invoked and the hook queue is empty.
func (f *ProviderUseTLSFunc) SetDefaultHook(hook func() bool) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// UseTLS method of the parent MockProvider instance invokes the hook at the
// front of the queue and discards it. After the queue is empty, the default
// hook function is invoked for any future action.
func (f *ProviderUseTLSFunc) PushHook(hook func() bool) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ProviderUseTLSFunc) SetDefaultReturn(r0 bool) {
f.SetDefaultHook(func() bool {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ProviderUseTLSFunc) PushReturn(r0 bool) {
f.PushHook(func() bool {
return r0
})
}
func (f *ProviderUseTLSFunc) nextHook() func() bool {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ProviderUseTLSFunc) appendCall(r0 ProviderUseTLSFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of ProviderUseTLSFuncCall objects describing
// the invocations of this function.
func (f *ProviderUseTLSFunc) History() []ProviderUseTLSFuncCall {
f.mutex.Lock()
history := make([]ProviderUseTLSFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ProviderUseTLSFuncCall is an object that describes an invocation of
// method UseTLS on an instance of MockProvider.
type ProviderUseTLSFuncCall struct {
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 bool
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ProviderUseTLSFuncCall) Args() []interface{} {
return []interface{}{}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ProviderUseTLSFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}

File diff suppressed because it is too large Load Diff

View File

@ -224,7 +224,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
stats.Counter.Follow, _ = x.Count(new(Follow))
stats.Counter.Mirror, _ = x.Count(new(Mirror))
stats.Counter.Release, _ = x.Count(new(Release))
stats.Counter.LoginSource = LoginSources.Count(ctx)
stats.Counter.LoginSource = Handle.LoginSources().Count(ctx)
stats.Counter.Webhook, _ = x.Count(new(Webhook))
stats.Counter.Milestone, _ = x.Count(new(Milestone))
stats.Counter.Label, _ = x.Count(new(Label))

View File

@ -185,7 +185,7 @@ func (s *usersStore) Authenticate(ctx context.Context, login, password string, l
user := new(User)
err := query.First(user).Error
if err != nil && err != gorm.ErrRecordNotFound {
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.Wrap(err, "get user")
}
@ -221,7 +221,7 @@ func (s *usersStore) Authenticate(ctx context.Context, login, password string, l
createNewUser = true
}
source, err := LoginSources.GetByID(ctx, authSourceID)
source, err := newLoginSourcesStore(s.DB, loadedLoginSourceFilesStore).GetByID(ctx, authSourceID)
if err != nil {
return nil, errors.Wrap(err, "get login source")
}

View File

@ -175,17 +175,19 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
})
t.Run("via login source", func(t *testing.T) {
mockLoginSources := NewMockLoginSourcesStore()
mockLoginSources.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int64) (*LoginSource, error) {
mockProvider := NewMockProvider()
mockProvider.AuthenticateFunc.SetDefaultReturn(&auth.ExternalAccount{}, nil)
s := &LoginSource{
IsActived: true,
Provider: mockProvider,
}
return s, nil
})
setMockLoginSourcesStore(t, mockLoginSources)
loginSourcesStore := newLoginSourcesStore(db.DB, NewMockLoginSourceFilesStore())
loginSource, err := loginSourcesStore.Create(
ctx,
CreateLoginSourceOptions{
Type: auth.Mock,
Name: "mock-1",
Activated: true,
Config: mockProviderConfig{
ExternalAccount: &auth.ExternalAccount{},
},
},
)
require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com",
CreateUserOptions{
@ -195,31 +197,30 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
)
require.NoError(t, err)
user, err := db.Authenticate(ctx, bob.Email, password, 1)
user, err := db.Authenticate(ctx, bob.Email, password, loginSource.ID)
require.NoError(t, err)
assert.Equal(t, bob.Name, user.Name)
})
t.Run("new user via login source", func(t *testing.T) {
mockLoginSources := NewMockLoginSourcesStore()
mockLoginSources.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int64) (*LoginSource, error) {
mockProvider := NewMockProvider()
mockProvider.AuthenticateFunc.SetDefaultReturn(
&auth.ExternalAccount{
Name: "cindy",
Email: "cindy@example.com",
loginSourcesStore := newLoginSourcesStore(db.DB, NewMockLoginSourceFilesStore())
loginSource, err := loginSourcesStore.Create(
ctx,
CreateLoginSourceOptions{
Type: auth.Mock,
Name: "mock-2",
Activated: true,
Config: mockProviderConfig{
ExternalAccount: &auth.ExternalAccount{
Name: "cindy",
Email: "cindy@example.com",
},
},
nil,
)
s := &LoginSource{
IsActived: true,
Provider: mockProvider,
}
return s, nil
})
setMockLoginSourcesStore(t, mockLoginSources)
},
)
require.NoError(t, err)
user, err := db.Authenticate(ctx, "cindy", password, 1)
user, err := db.Authenticate(ctx, "cindy", password, loginSource.ID)
require.NoError(t, err)
assert.Equal(t, "cindy", user.Name)

View File

@ -35,13 +35,13 @@ func Authentications(c *context.Context) {
c.PageIs("AdminAuthentications")
var err error
c.Data["Sources"], err = database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
c.Data["Sources"], err = database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
if err != nil {
c.Error(err, "list login sources")
return
}
c.Data["Total"] = database.LoginSources.Count(c.Req.Context())
c.Data["Total"] = database.Handle.LoginSources().Count(c.Req.Context())
c.Success(AUTHS)
}
@ -159,7 +159,7 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
return
}
source, err := database.LoginSources.Create(c.Req.Context(),
source, err := database.Handle.LoginSources().Create(c.Req.Context(),
database.CreateLoginSourceOptions{
Type: auth.Type(f.Type),
Name: f.Name,
@ -179,7 +179,7 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
}
if source.IsDefault {
err = database.LoginSources.ResetNonDefault(c.Req.Context(), source)
err = database.Handle.LoginSources().ResetNonDefault(c.Req.Context(), source)
if err != nil {
c.Error(err, "reset non-default login sources")
return
@ -200,7 +200,7 @@ func EditAuthSource(c *context.Context) {
c.Data["SecurityProtocols"] = securityProtocols
c.Data["SMTPAuths"] = smtp.AuthTypes
source, err := database.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
source, err := database.Handle.LoginSources().GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
if err != nil {
c.Error(err, "get login source by ID")
return
@ -218,7 +218,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
c.Data["SMTPAuths"] = smtp.AuthTypes
source, err := database.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
source, err := database.Handle.LoginSources().GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
if err != nil {
c.Error(err, "get login source by ID")
return
@ -257,13 +257,13 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
source.IsActived = f.IsActive
source.IsDefault = f.IsDefault
source.Provider = provider
if err := database.LoginSources.Save(c.Req.Context(), source); err != nil {
if err := database.Handle.LoginSources().Save(c.Req.Context(), source); err != nil {
c.Error(err, "update login source")
return
}
if source.IsDefault {
err = database.LoginSources.ResetNonDefault(c.Req.Context(), source)
err = database.Handle.LoginSources().ResetNonDefault(c.Req.Context(), source)
if err != nil {
c.Error(err, "reset non-default login sources")
return
@ -278,7 +278,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
func DeleteAuthSource(c *context.Context) {
id := c.ParamsInt64(":authid")
if err := database.LoginSources.DeleteByID(c.Req.Context(), id); err != nil {
if err := database.Handle.LoginSources().DeleteByID(c.Req.Context(), id); err != nil {
if database.IsErrLoginSourceInUse(err) {
c.Flash.Error(c.Tr("admin.auths.still_in_used"))
} else {

View File

@ -46,7 +46,7 @@ func NewUser(c *context.Context) {
c.Data["login_type"] = "0-0"
sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
if err != nil {
c.Error(err, "list login sources")
return
@ -62,7 +62,7 @@ func NewUserPost(c *context.Context, f form.AdminCrateUser) {
c.Data["PageIsAdmin"] = true
c.Data["PageIsAdminUsers"] = true
sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
if err != nil {
c.Error(err, "list login sources")
return
@ -125,7 +125,7 @@ func prepareUserInfo(c *context.Context) *database.User {
c.Data["User"] = u
if u.LoginSource > 0 {
c.Data["LoginSource"], err = database.LoginSources.GetByID(c.Req.Context(), u.LoginSource)
c.Data["LoginSource"], err = database.Handle.LoginSources().GetByID(c.Req.Context(), u.LoginSource)
if err != nil {
c.Error(err, "get login source by ID")
return nil
@ -134,7 +134,7 @@ func prepareUserInfo(c *context.Context) *database.User {
c.Data["LoginSource"] = &database.LoginSource{}
}
sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
if err != nil {
c.Error(err, "list login sources")
return nil

View File

@ -22,7 +22,7 @@ func parseLoginSource(c *context.APIContext, sourceID int64) {
return
}
_, err := database.LoginSources.GetByID(c.Req.Context(), sourceID)
_, err := database.Handle.LoginSources().GetByID(c.Req.Context(), sourceID)
if err != nil {
if database.IsErrLoginSourceNotExist(err) {
c.ErrorStatus(http.StatusUnprocessableEntity, err)

View File

@ -106,7 +106,7 @@ func Login(c *context.Context) {
}
// Display normal login page
loginSources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
loginSources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
if err != nil {
c.Error(err, "list activated login sources")
return
@ -153,7 +153,7 @@ func afterLogin(c *context.Context, u *database.User, remember bool) {
func LoginPost(c *context.Context, f form.SignIn) {
c.Title("sign_in")
loginSources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
loginSources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
if err != nil {
c.Error(err, "list activated login sources")
return

View File

@ -25,10 +25,10 @@ mocks:
sources:
- path: gogs.io/gogs/internal/database
interfaces:
- LoginSourcesStore
- LoginSourceFilesStore
- LoginSourceFileStore
- loginSourceFileStore
- loginSourceFilesStore
- filename: internal/database/mocks_gen.go
sources:
- path: gogs.io/gogs/internal/auth
interfaces:
- Provider