mirror of https://github.com/gogs/gogs.git
all: unwrap `database.LoginSourcesStore` interface (#7694)
parent
3a5132b6f7
commit
e634aa6277
|
@ -23,6 +23,8 @@ const (
|
|||
PAM // 4
|
||||
DLDAP // 5
|
||||
GitHub // 6
|
||||
|
||||
Mock Type = 999
|
||||
)
|
||||
|
||||
// Name returns the human-readable name for given authentication type.
|
||||
|
@ -45,8 +47,7 @@ type ErrBadCredentials struct {
|
|||
// IsErrBadCredentials returns true if the underlying error has the type
|
||||
// ErrBadCredentials.
|
||||
func IsErrBadCredentials(err error) bool {
|
||||
_, ok := errors.Cause(err).(ErrBadCredentials)
|
||||
return ok
|
||||
return errors.As(err, &ErrBadCredentials{})
|
||||
}
|
||||
|
||||
func (err ErrBadCredentials) Error() string {
|
||||
|
|
|
@ -117,13 +117,12 @@ func NewConnection(w logger.Writer) (*gorm.DB, error) {
|
|||
log.Trace("Auto migrated %q", name)
|
||||
}
|
||||
|
||||
sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
|
||||
loadedLoginSourceFilesStore, err = loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "load login source files")
|
||||
}
|
||||
|
||||
// Initialize stores, sorted in alphabetical order.
|
||||
LoginSources = &loginSourcesStore{DB: db, files: sourceFiles}
|
||||
Notices = NewNoticesStore(db)
|
||||
Orgs = NewOrgsStore(db)
|
||||
Perms = NewPermsStore(db)
|
||||
|
@ -166,3 +165,11 @@ func (db *DB) Actions() *ActionsStore {
|
|||
func (db *DB) LFS() *LFSStore {
|
||||
return newLFSStore(db.db)
|
||||
}
|
||||
|
||||
// NOTE: It is not guarded by a mutex because it only gets written during the
|
||||
// service start.
|
||||
var loadedLoginSourceFilesStore loginSourceFilesStore
|
||||
|
||||
func (db *DB) LoginSources() *LoginSourcesStore {
|
||||
return newLoginSourcesStore(db.db, loadedLoginSourceFilesStore)
|
||||
}
|
||||
|
|
|
@ -52,8 +52,7 @@ type ErrLoginSourceNotExist struct {
|
|||
}
|
||||
|
||||
func IsErrLoginSourceNotExist(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceNotExist)
|
||||
return ok
|
||||
return errors.As(err, &ErrLoginSourceNotExist{})
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceNotExist) Error() string {
|
||||
|
|
|
@ -22,30 +22,6 @@ import (
|
|||
"gogs.io/gogs/internal/errutil"
|
||||
)
|
||||
|
||||
// LoginSourcesStore is the persistent interface for login sources.
|
||||
type LoginSourcesStore interface {
|
||||
// Create creates a new login source and persist to database. It returns
|
||||
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
|
||||
Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error)
|
||||
// Count returns the total number of login sources.
|
||||
Count(ctx context.Context) int64
|
||||
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
|
||||
// if at least one user is associated with the login source.
|
||||
DeleteByID(ctx context.Context, id int64) error
|
||||
// GetByID returns the login source with given ID. It returns
|
||||
// ErrLoginSourceNotExist when not found.
|
||||
GetByID(ctx context.Context, id int64) (*LoginSource, error)
|
||||
// List returns a list of login sources filtered by options.
|
||||
List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error)
|
||||
// ResetNonDefault clears default flag for all the other login sources.
|
||||
ResetNonDefault(ctx context.Context, source *LoginSource) error
|
||||
// Save persists all values of given login source to database or local file. The
|
||||
// Updated field is set to current time automatically.
|
||||
Save(ctx context.Context, t *LoginSource) error
|
||||
}
|
||||
|
||||
var LoginSources LoginSourcesStore
|
||||
|
||||
// LoginSource represents an external way for authorizing users.
|
||||
type LoginSource struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
|
@ -88,6 +64,10 @@ func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type mockProviderConfig struct {
|
||||
ExternalAccount *auth.ExternalAccount
|
||||
}
|
||||
|
||||
// AfterFind implements the GORM query hook.
|
||||
func (s *LoginSource) AfterFind(_ *gorm.DB) error {
|
||||
s.Created = time.Unix(s.CreatedUnix, 0).Local()
|
||||
|
@ -134,6 +114,16 @@ func (s *LoginSource) AfterFind(_ *gorm.DB) error {
|
|||
}
|
||||
s.Provider = github.NewProvider(&cfg)
|
||||
|
||||
case auth.Mock:
|
||||
var cfg mockProviderConfig
|
||||
err := jsoniter.UnmarshalFromString(s.Config, &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mockProvider := NewMockProvider()
|
||||
mockProvider.AuthenticateFunc.SetDefaultReturn(cfg.ExternalAccount, nil)
|
||||
s.Provider = mockProvider
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unrecognized login source type: %v", s.Type)
|
||||
}
|
||||
|
@ -180,13 +170,19 @@ func (s *LoginSource) GitHub() *github.Config {
|
|||
return s.Provider.Config().(*github.Config)
|
||||
}
|
||||
|
||||
var _ LoginSourcesStore = (*loginSourcesStore)(nil)
|
||||
|
||||
type loginSourcesStore struct {
|
||||
*gorm.DB
|
||||
// LoginSourcesStore is the storage layer for login sources.
|
||||
type LoginSourcesStore struct {
|
||||
db *gorm.DB
|
||||
files loginSourceFilesStore
|
||||
}
|
||||
|
||||
func newLoginSourcesStore(db *gorm.DB, files loginSourceFilesStore) *LoginSourcesStore {
|
||||
return &LoginSourcesStore{
|
||||
db: db,
|
||||
files: files,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateLoginSourceOptions struct {
|
||||
Type auth.Type
|
||||
Name string
|
||||
|
@ -200,19 +196,20 @@ type ErrLoginSourceAlreadyExist struct {
|
|||
}
|
||||
|
||||
func IsErrLoginSourceAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceAlreadyExist)
|
||||
return ok
|
||||
return errors.As(err, &ErrLoginSourceAlreadyExist{})
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists: %v", err.args)
|
||||
}
|
||||
|
||||
func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
|
||||
err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
|
||||
// Create creates a new login source and persists it to the database. It returns
|
||||
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
|
||||
func (s *LoginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
|
||||
err := s.db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
|
||||
if err == nil {
|
||||
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
|
||||
} else if err != gorm.ErrRecordNotFound {
|
||||
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -226,12 +223,13 @@ func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return source, s.WithContext(ctx).Create(source).Error
|
||||
return source, s.db.WithContext(ctx).Create(source).Error
|
||||
}
|
||||
|
||||
func (s *loginSourcesStore) Count(ctx context.Context) int64 {
|
||||
// Count returns the total number of login sources.
|
||||
func (s *LoginSourcesStore) Count(ctx context.Context) int64 {
|
||||
var count int64
|
||||
s.WithContext(ctx).Model(new(LoginSource)).Count(&count)
|
||||
s.db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
|
||||
return count + int64(s.files.Len())
|
||||
}
|
||||
|
||||
|
@ -240,31 +238,34 @@ type ErrLoginSourceInUse struct {
|
|||
}
|
||||
|
||||
func IsErrLoginSourceInUse(err error) bool {
|
||||
_, ok := err.(ErrLoginSourceInUse)
|
||||
return ok
|
||||
return errors.As(err, &ErrLoginSourceInUse{})
|
||||
}
|
||||
|
||||
func (err ErrLoginSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users: %v", err.args)
|
||||
}
|
||||
|
||||
func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
|
||||
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
|
||||
// if at least one user is associated with the login source.
|
||||
func (s *LoginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
|
||||
var count int64
|
||||
err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
|
||||
err := s.db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count > 0 {
|
||||
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
|
||||
}
|
||||
|
||||
return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
|
||||
return s.db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
|
||||
}
|
||||
|
||||
func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
|
||||
// GetByID returns the login source with given ID. It returns
|
||||
// ErrLoginSourceNotExist when not found.
|
||||
func (s *LoginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
|
||||
source := new(LoginSource)
|
||||
err := s.WithContext(ctx).Where("id = ?", id).First(source).Error
|
||||
err := s.db.WithContext(ctx).Where("id = ?", id).First(source).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return s.files.GetByID(id)
|
||||
}
|
||||
return nil, err
|
||||
|
@ -277,9 +278,10 @@ type ListLoginSourceOptions struct {
|
|||
OnlyActivated bool
|
||||
}
|
||||
|
||||
func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
|
||||
// List returns a list of login sources filtered by options.
|
||||
func (s *LoginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
|
||||
var sources []*LoginSource
|
||||
query := s.WithContext(ctx).Order("id ASC")
|
||||
query := s.db.WithContext(ctx).Order("id ASC")
|
||||
if opts.OnlyActivated {
|
||||
query = query.Where("is_actived = ?", true)
|
||||
}
|
||||
|
@ -291,8 +293,9 @@ func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOption
|
|||
return append(sources, s.files.List(opts)...), nil
|
||||
}
|
||||
|
||||
func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
|
||||
err := s.WithContext(ctx).
|
||||
// ResetNonDefault clears default flag for all the other login sources.
|
||||
func (s *LoginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(new(LoginSource)).
|
||||
Where("id != ?", dflt.ID).
|
||||
Updates(map[string]any{"is_default": false}).
|
||||
|
@ -314,9 +317,11 @@ func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSour
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
|
||||
// Save persists all values of given login source to database or local file. The
|
||||
// Updated field is set to current time automatically.
|
||||
func (s *LoginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
|
||||
if source.File == nil {
|
||||
return s.WithContext(ctx).Save(source).Error
|
||||
return s.db.WithContext(ctx).Save(source).Error
|
||||
}
|
||||
|
||||
source.File.SetGeneral("name", source.Name)
|
||||
|
|
|
@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &loginSourcesStore{
|
||||
DB: newTestDB(t, "loginSourcesStore"),
|
||||
s := &LoginSourcesStore{
|
||||
db: newTestDB(t, "LoginSourcesStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
||||
test func(t *testing.T, ctx context.Context, s *LoginSourcesStore)
|
||||
}{
|
||||
{"Create", loginSourcesCreate},
|
||||
{"Count", loginSourcesCount},
|
||||
|
@ -181,10 +181,10 @@ func TestLoginSources(t *testing.T) {
|
|||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
err := clearTables(t, db.DB)
|
||||
err := clearTables(t, s.db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
tc.test(t, ctx, db)
|
||||
tc.test(t, ctx, s)
|
||||
})
|
||||
if t.Failed() {
|
||||
break
|
||||
|
@ -192,9 +192,9 @@ func TestLoginSources(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
func loginSourcesCreate(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
||||
// Create first login source with name "GitHub"
|
||||
source, err := db.Create(ctx,
|
||||
source, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
|
@ -208,20 +208,28 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore
|
|||
require.NoError(t, err)
|
||||
|
||||
// Get it back and check the Created field
|
||||
source, err = db.GetByID(ctx, source.ID)
|
||||
source, err = s.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
|
||||
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
|
||||
assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
|
||||
assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
|
||||
|
||||
// Try create second login source with same name should fail
|
||||
_, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
|
||||
// Try to create second login source with same name should fail.
|
||||
_, err = s.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
|
||||
wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
func setMockLoginSourceFilesStore(t *testing.T, s *LoginSourcesStore, mock loginSourceFilesStore) {
|
||||
before := s.files
|
||||
s.files = mock
|
||||
t.Cleanup(func() {
|
||||
s.files = before
|
||||
})
|
||||
}
|
||||
|
||||
func loginSourcesCount(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
||||
// Create two login sources, one in database and one as source file.
|
||||
_, err := db.Create(ctx,
|
||||
_, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
|
@ -236,14 +244,14 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|||
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.LenFunc.SetDefaultReturn(2)
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
setMockLoginSourceFilesStore(t, s, mock)
|
||||
|
||||
assert.Equal(t, int64(3), db.Count(ctx))
|
||||
assert.Equal(t, int64(3), s.Count(ctx))
|
||||
}
|
||||
|
||||
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
||||
t.Run("delete but in used", func(t *testing.T) {
|
||||
source, err := db.Create(ctx,
|
||||
source, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
|
@ -257,7 +265,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|||
require.NoError(t, err)
|
||||
|
||||
// Create a user that uses this login source
|
||||
_, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
|
||||
_, err = NewUsersStore(s.db).Create(ctx, "alice", "",
|
||||
CreateUserOptions{
|
||||
LoginSource: source.ID,
|
||||
},
|
||||
|
@ -265,7 +273,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|||
require.NoError(t, err)
|
||||
|
||||
// Delete the login source will result in error
|
||||
err = db.DeleteByID(ctx, source.ID)
|
||||
err = s.DeleteByID(ctx, source.ID)
|
||||
wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
|
||||
assert.Equal(t, wantErr, err)
|
||||
})
|
||||
|
@ -274,10 +282,10 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|||
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
setMockLoginSourceFilesStore(t, s, mock)
|
||||
|
||||
// Create a login source with name "GitHub2"
|
||||
source, err := db.Create(ctx,
|
||||
source, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub2",
|
||||
|
@ -291,24 +299,24 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|||
require.NoError(t, err)
|
||||
|
||||
// Delete a non-existent ID is noop
|
||||
err = db.DeleteByID(ctx, 9999)
|
||||
err = s.DeleteByID(ctx, 9999)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should be able to get it back
|
||||
_, err = db.GetByID(ctx, source.ID)
|
||||
_, err = s.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now delete this login source with ID
|
||||
err = db.DeleteByID(ctx, source.ID)
|
||||
err = s.DeleteByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should get token not found error
|
||||
_, err = db.GetByID(ctx, source.ID)
|
||||
_, err = s.GetByID(ctx, source.ID)
|
||||
wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
func loginSourcesGetByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
||||
if id != 101 {
|
||||
|
@ -316,14 +324,14 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
|
|||
}
|
||||
return &LoginSource{ID: id}, nil
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
setMockLoginSourceFilesStore(t, s, mock)
|
||||
|
||||
expConfig := &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
}
|
||||
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(ctx,
|
||||
source, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
|
@ -335,16 +343,16 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
|
|||
require.NoError(t, err)
|
||||
|
||||
// Get the one in the database and test the read/write hooks
|
||||
source, err = db.GetByID(ctx, source.ID)
|
||||
source, err = s.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expConfig, source.Provider.Config())
|
||||
|
||||
// Get the one in source file store
|
||||
_, err = db.GetByID(ctx, 101)
|
||||
_, err = s.GetByID(ctx, 101)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
func loginSourcesList(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
||||
if opts.OnlyActivated {
|
||||
|
@ -357,10 +365,10 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|||
{ID: 2},
|
||||
}
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
setMockLoginSourceFilesStore(t, s, mock)
|
||||
|
||||
// Create two login sources in database, one activated and the other one not
|
||||
_, err := db.Create(ctx,
|
||||
_, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.PAM,
|
||||
Name: "PAM",
|
||||
|
@ -370,7 +378,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = db.Create(ctx,
|
||||
_, err = s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
|
@ -383,17 +391,17 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|||
require.NoError(t, err)
|
||||
|
||||
// List all login sources
|
||||
sources, err := db.List(ctx, ListLoginSourceOptions{})
|
||||
sources, err := s.List(ctx, ListLoginSourceOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 4, len(sources), "number of sources")
|
||||
|
||||
// Only list activated login sources
|
||||
sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
|
||||
sources, err = s.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(sources), "number of sources")
|
||||
}
|
||||
|
||||
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
||||
mockFile := NewMockLoginSourceFileStore()
|
||||
|
@ -407,10 +415,10 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
|
|||
},
|
||||
}
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
setMockLoginSourceFilesStore(t, s, mock)
|
||||
|
||||
// Create two login sources both have default on
|
||||
source1, err := db.Create(ctx,
|
||||
source1, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.PAM,
|
||||
Name: "PAM",
|
||||
|
@ -421,7 +429,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
|
|||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
source2, err := db.Create(ctx,
|
||||
source2, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
|
@ -435,23 +443,23 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set source 1 as default
|
||||
err = db.ResetNonDefault(ctx, source1)
|
||||
err = s.ResetNonDefault(ctx, source1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the default state
|
||||
source1, err = db.GetByID(ctx, source1.ID)
|
||||
source1, err = s.GetByID(ctx, source1.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, source1.IsDefault)
|
||||
|
||||
source2, err = db.GetByID(ctx, source2.ID)
|
||||
source2, err = s.GetByID(ctx, source2.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, source2.IsDefault)
|
||||
}
|
||||
|
||||
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
func loginSourcesSave(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
||||
t.Run("save to database", func(t *testing.T) {
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(ctx,
|
||||
source, err := s.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
|
@ -468,10 +476,10 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|||
source.Provider = github.NewProvider(&github.Config{
|
||||
APIEndpoint: "https://api2.github.com",
|
||||
})
|
||||
err = db.Save(ctx, source)
|
||||
err = s.Save(ctx, source)
|
||||
require.NoError(t, err)
|
||||
|
||||
source, err = db.GetByID(ctx, source.ID)
|
||||
source, err = s.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, source.IsActived)
|
||||
assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
|
||||
|
@ -485,7 +493,7 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|||
}),
|
||||
File: mockFile,
|
||||
}
|
||||
err := db.Save(ctx, source)
|
||||
err := s.Save(ctx, source)
|
||||
require.NoError(t, err)
|
||||
mockrequire.Called(t, mockFile.SaveFunc)
|
||||
})
|
||||
|
|
|
@ -8,22 +8,6 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func setMockLoginSourcesStore(t *testing.T, mock LoginSourcesStore) {
|
||||
before := LoginSources
|
||||
LoginSources = mock
|
||||
t.Cleanup(func() {
|
||||
LoginSources = before
|
||||
})
|
||||
}
|
||||
|
||||
func setMockLoginSourceFilesStore(t *testing.T, db *loginSourcesStore, mock loginSourceFilesStore) {
|
||||
before := db.files
|
||||
db.files = mock
|
||||
t.Cleanup(func() {
|
||||
db.files = before
|
||||
})
|
||||
}
|
||||
|
||||
func SetMockPermsStore(t *testing.T, mock PermsStore) {
|
||||
before := Perms
|
||||
Perms = mock
|
||||
|
|
|
@ -0,0 +1,620 @@
|
|||
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
|
||||
//
|
||||
// This file was generated by running `go-mockgen` at the root of this repository.
|
||||
// To add additional mocks to this or another package, add a new entry to the
|
||||
// mockgen.yaml file in the root of this repository.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
auth "gogs.io/gogs/internal/auth"
|
||||
)
|
||||
|
||||
// MockProvider is a mock implementation of the Provider interface (from the
|
||||
// package gogs.io/gogs/internal/auth) used for unit testing.
|
||||
type MockProvider struct {
|
||||
// AuthenticateFunc is an instance of a mock function object controlling
|
||||
// the behavior of the method Authenticate.
|
||||
AuthenticateFunc *ProviderAuthenticateFunc
|
||||
// ConfigFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method Config.
|
||||
ConfigFunc *ProviderConfigFunc
|
||||
// HasTLSFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method HasTLS.
|
||||
HasTLSFunc *ProviderHasTLSFunc
|
||||
// SkipTLSVerifyFunc is an instance of a mock function object
|
||||
// controlling the behavior of the method SkipTLSVerify.
|
||||
SkipTLSVerifyFunc *ProviderSkipTLSVerifyFunc
|
||||
// UseTLSFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method UseTLS.
|
||||
UseTLSFunc *ProviderUseTLSFunc
|
||||
}
|
||||
|
||||
// NewMockProvider creates a new mock of the Provider interface. All methods
|
||||
// return zero values for all results, unless overwritten.
|
||||
func NewMockProvider() *MockProvider {
|
||||
return &MockProvider{
|
||||
AuthenticateFunc: &ProviderAuthenticateFunc{
|
||||
defaultHook: func(string, string) (r0 *auth.ExternalAccount, r1 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
ConfigFunc: &ProviderConfigFunc{
|
||||
defaultHook: func() (r0 interface{}) {
|
||||
return
|
||||
},
|
||||
},
|
||||
HasTLSFunc: &ProviderHasTLSFunc{
|
||||
defaultHook: func() (r0 bool) {
|
||||
return
|
||||
},
|
||||
},
|
||||
SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
|
||||
defaultHook: func() (r0 bool) {
|
||||
return
|
||||
},
|
||||
},
|
||||
UseTLSFunc: &ProviderUseTLSFunc{
|
||||
defaultHook: func() (r0 bool) {
|
||||
return
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewStrictMockProvider creates a new mock of the Provider interface. All
|
||||
// methods panic on invocation, unless overwritten.
|
||||
func NewStrictMockProvider() *MockProvider {
|
||||
return &MockProvider{
|
||||
AuthenticateFunc: &ProviderAuthenticateFunc{
|
||||
defaultHook: func(string, string) (*auth.ExternalAccount, error) {
|
||||
panic("unexpected invocation of MockProvider.Authenticate")
|
||||
},
|
||||
},
|
||||
ConfigFunc: &ProviderConfigFunc{
|
||||
defaultHook: func() interface{} {
|
||||
panic("unexpected invocation of MockProvider.Config")
|
||||
},
|
||||
},
|
||||
HasTLSFunc: &ProviderHasTLSFunc{
|
||||
defaultHook: func() bool {
|
||||
panic("unexpected invocation of MockProvider.HasTLS")
|
||||
},
|
||||
},
|
||||
SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
|
||||
defaultHook: func() bool {
|
||||
panic("unexpected invocation of MockProvider.SkipTLSVerify")
|
||||
},
|
||||
},
|
||||
UseTLSFunc: &ProviderUseTLSFunc{
|
||||
defaultHook: func() bool {
|
||||
panic("unexpected invocation of MockProvider.UseTLS")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockProviderFrom creates a new mock of the MockProvider interface. All
|
||||
// methods delegate to the given implementation, unless overwritten.
|
||||
func NewMockProviderFrom(i auth.Provider) *MockProvider {
|
||||
return &MockProvider{
|
||||
AuthenticateFunc: &ProviderAuthenticateFunc{
|
||||
defaultHook: i.Authenticate,
|
||||
},
|
||||
ConfigFunc: &ProviderConfigFunc{
|
||||
defaultHook: i.Config,
|
||||
},
|
||||
HasTLSFunc: &ProviderHasTLSFunc{
|
||||
defaultHook: i.HasTLS,
|
||||
},
|
||||
SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
|
||||
defaultHook: i.SkipTLSVerify,
|
||||
},
|
||||
UseTLSFunc: &ProviderUseTLSFunc{
|
||||
defaultHook: i.UseTLS,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderAuthenticateFunc describes the behavior when the Authenticate
|
||||
// method of the parent MockProvider instance is invoked.
|
||||
type ProviderAuthenticateFunc struct {
|
||||
defaultHook func(string, string) (*auth.ExternalAccount, error)
|
||||
hooks []func(string, string) (*auth.ExternalAccount, error)
|
||||
history []ProviderAuthenticateFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Authenticate delegates to the next hook function in the queue and stores
|
||||
// the parameter and result values of this invocation.
|
||||
func (m *MockProvider) Authenticate(v0 string, v1 string) (*auth.ExternalAccount, error) {
|
||||
r0, r1 := m.AuthenticateFunc.nextHook()(v0, v1)
|
||||
m.AuthenticateFunc.appendCall(ProviderAuthenticateFuncCall{v0, v1, r0, r1})
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the Authenticate method
|
||||
// of the parent MockProvider instance is invoked and the hook queue is
|
||||
// empty.
|
||||
func (f *ProviderAuthenticateFunc) SetDefaultHook(hook func(string, string) (*auth.ExternalAccount, error)) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// Authenticate method of the parent MockProvider instance invokes the hook
|
||||
// at the front of the queue and discards it. After the queue is empty, the
|
||||
// default hook function is invoked for any future action.
|
||||
func (f *ProviderAuthenticateFunc) PushHook(hook func(string, string) (*auth.ExternalAccount, error)) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ProviderAuthenticateFunc) SetDefaultReturn(r0 *auth.ExternalAccount, r1 error) {
|
||||
f.SetDefaultHook(func(string, string) (*auth.ExternalAccount, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ProviderAuthenticateFunc) PushReturn(r0 *auth.ExternalAccount, r1 error) {
|
||||
f.PushHook(func(string, string) (*auth.ExternalAccount, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ProviderAuthenticateFunc) nextHook() func(string, string) (*auth.ExternalAccount, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ProviderAuthenticateFunc) appendCall(r0 ProviderAuthenticateFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ProviderAuthenticateFuncCall objects
|
||||
// describing the invocations of this function.
|
||||
func (f *ProviderAuthenticateFunc) History() []ProviderAuthenticateFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ProviderAuthenticateFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ProviderAuthenticateFuncCall is an object that describes an invocation of
|
||||
// method Authenticate on an instance of MockProvider.
|
||||
type ProviderAuthenticateFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 string
|
||||
// Arg1 is the value of the 2nd argument passed to this method
|
||||
// invocation.
|
||||
Arg1 string
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 *auth.ExternalAccount
|
||||
// Result1 is the value of the 2nd result returned from this method
|
||||
// invocation.
|
||||
Result1 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ProviderAuthenticateFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0, c.Arg1}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ProviderAuthenticateFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1}
|
||||
}
|
||||
|
||||
// ProviderConfigFunc describes the behavior when the Config method of the
|
||||
// parent MockProvider instance is invoked.
|
||||
type ProviderConfigFunc struct {
|
||||
defaultHook func() interface{}
|
||||
hooks []func() interface{}
|
||||
history []ProviderConfigFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Config delegates to the next hook function in the queue and stores the
|
||||
// parameter and result values of this invocation.
|
||||
func (m *MockProvider) Config() interface{} {
|
||||
r0 := m.ConfigFunc.nextHook()()
|
||||
m.ConfigFunc.appendCall(ProviderConfigFuncCall{r0})
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the Config method of the
|
||||
// parent MockProvider instance is invoked and the hook queue is empty.
|
||||
func (f *ProviderConfigFunc) SetDefaultHook(hook func() interface{}) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// Config method of the parent MockProvider instance invokes the hook at the
|
||||
// front of the queue and discards it. After the queue is empty, the default
|
||||
// hook function is invoked for any future action.
|
||||
func (f *ProviderConfigFunc) PushHook(hook func() interface{}) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ProviderConfigFunc) SetDefaultReturn(r0 interface{}) {
|
||||
f.SetDefaultHook(func() interface{} {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ProviderConfigFunc) PushReturn(r0 interface{}) {
|
||||
f.PushHook(func() interface{} {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ProviderConfigFunc) nextHook() func() interface{} {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ProviderConfigFunc) appendCall(r0 ProviderConfigFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ProviderConfigFuncCall objects describing
|
||||
// the invocations of this function.
|
||||
func (f *ProviderConfigFunc) History() []ProviderConfigFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ProviderConfigFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ProviderConfigFuncCall is an object that describes an invocation of
|
||||
// method Config on an instance of MockProvider.
|
||||
type ProviderConfigFuncCall struct {
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 interface{}
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ProviderConfigFuncCall) Args() []interface{} {
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ProviderConfigFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0}
|
||||
}
|
||||
|
||||
// ProviderHasTLSFunc describes the behavior when the HasTLS method of the
|
||||
// parent MockProvider instance is invoked.
|
||||
type ProviderHasTLSFunc struct {
|
||||
defaultHook func() bool
|
||||
hooks []func() bool
|
||||
history []ProviderHasTLSFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// HasTLS delegates to the next hook function in the queue and stores the
|
||||
// parameter and result values of this invocation.
|
||||
func (m *MockProvider) HasTLS() bool {
|
||||
r0 := m.HasTLSFunc.nextHook()()
|
||||
m.HasTLSFunc.appendCall(ProviderHasTLSFuncCall{r0})
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the HasTLS method of the
|
||||
// parent MockProvider instance is invoked and the hook queue is empty.
|
||||
func (f *ProviderHasTLSFunc) SetDefaultHook(hook func() bool) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// HasTLS method of the parent MockProvider instance invokes the hook at the
|
||||
// front of the queue and discards it. After the queue is empty, the default
|
||||
// hook function is invoked for any future action.
|
||||
func (f *ProviderHasTLSFunc) PushHook(hook func() bool) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ProviderHasTLSFunc) SetDefaultReturn(r0 bool) {
|
||||
f.SetDefaultHook(func() bool {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ProviderHasTLSFunc) PushReturn(r0 bool) {
|
||||
f.PushHook(func() bool {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ProviderHasTLSFunc) nextHook() func() bool {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ProviderHasTLSFunc) appendCall(r0 ProviderHasTLSFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ProviderHasTLSFuncCall objects describing
|
||||
// the invocations of this function.
|
||||
func (f *ProviderHasTLSFunc) History() []ProviderHasTLSFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ProviderHasTLSFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ProviderHasTLSFuncCall is an object that describes an invocation of
|
||||
// method HasTLS on an instance of MockProvider.
|
||||
type ProviderHasTLSFuncCall struct {
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 bool
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ProviderHasTLSFuncCall) Args() []interface{} {
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ProviderHasTLSFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0}
|
||||
}
|
||||
|
||||
// ProviderSkipTLSVerifyFunc describes the behavior when the SkipTLSVerify
|
||||
// method of the parent MockProvider instance is invoked.
|
||||
type ProviderSkipTLSVerifyFunc struct {
|
||||
defaultHook func() bool
|
||||
hooks []func() bool
|
||||
history []ProviderSkipTLSVerifyFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// SkipTLSVerify delegates to the next hook function in the queue and stores
|
||||
// the parameter and result values of this invocation.
|
||||
func (m *MockProvider) SkipTLSVerify() bool {
|
||||
r0 := m.SkipTLSVerifyFunc.nextHook()()
|
||||
m.SkipTLSVerifyFunc.appendCall(ProviderSkipTLSVerifyFuncCall{r0})
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the SkipTLSVerify method
|
||||
// of the parent MockProvider instance is invoked and the hook queue is
|
||||
// empty.
|
||||
func (f *ProviderSkipTLSVerifyFunc) SetDefaultHook(hook func() bool) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// SkipTLSVerify method of the parent MockProvider instance invokes the hook
|
||||
// at the front of the queue and discards it. After the queue is empty, the
|
||||
// default hook function is invoked for any future action.
|
||||
func (f *ProviderSkipTLSVerifyFunc) PushHook(hook func() bool) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ProviderSkipTLSVerifyFunc) SetDefaultReturn(r0 bool) {
|
||||
f.SetDefaultHook(func() bool {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ProviderSkipTLSVerifyFunc) PushReturn(r0 bool) {
|
||||
f.PushHook(func() bool {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ProviderSkipTLSVerifyFunc) nextHook() func() bool {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ProviderSkipTLSVerifyFunc) appendCall(r0 ProviderSkipTLSVerifyFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ProviderSkipTLSVerifyFuncCall objects
|
||||
// describing the invocations of this function.
|
||||
func (f *ProviderSkipTLSVerifyFunc) History() []ProviderSkipTLSVerifyFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ProviderSkipTLSVerifyFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ProviderSkipTLSVerifyFuncCall is an object that describes an invocation
|
||||
// of method SkipTLSVerify on an instance of MockProvider.
|
||||
type ProviderSkipTLSVerifyFuncCall struct {
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 bool
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ProviderSkipTLSVerifyFuncCall) Args() []interface{} {
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ProviderSkipTLSVerifyFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0}
|
||||
}
|
||||
|
||||
// ProviderUseTLSFunc describes the behavior when the UseTLS method of the
|
||||
// parent MockProvider instance is invoked.
|
||||
type ProviderUseTLSFunc struct {
|
||||
defaultHook func() bool
|
||||
hooks []func() bool
|
||||
history []ProviderUseTLSFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// UseTLS delegates to the next hook function in the queue and stores the
|
||||
// parameter and result values of this invocation.
|
||||
func (m *MockProvider) UseTLS() bool {
|
||||
r0 := m.UseTLSFunc.nextHook()()
|
||||
m.UseTLSFunc.appendCall(ProviderUseTLSFuncCall{r0})
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the UseTLS method of the
|
||||
// parent MockProvider instance is invoked and the hook queue is empty.
|
||||
func (f *ProviderUseTLSFunc) SetDefaultHook(hook func() bool) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// UseTLS method of the parent MockProvider instance invokes the hook at the
|
||||
// front of the queue and discards it. After the queue is empty, the default
|
||||
// hook function is invoked for any future action.
|
||||
func (f *ProviderUseTLSFunc) PushHook(hook func() bool) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ProviderUseTLSFunc) SetDefaultReturn(r0 bool) {
|
||||
f.SetDefaultHook(func() bool {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ProviderUseTLSFunc) PushReturn(r0 bool) {
|
||||
f.PushHook(func() bool {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ProviderUseTLSFunc) nextHook() func() bool {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ProviderUseTLSFunc) appendCall(r0 ProviderUseTLSFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ProviderUseTLSFuncCall objects describing
|
||||
// the invocations of this function.
|
||||
func (f *ProviderUseTLSFunc) History() []ProviderUseTLSFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ProviderUseTLSFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ProviderUseTLSFuncCall is an object that describes an invocation of
|
||||
// method UseTLS on an instance of MockProvider.
|
||||
type ProviderUseTLSFuncCall struct {
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 bool
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ProviderUseTLSFuncCall) Args() []interface{} {
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ProviderUseTLSFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -224,7 +224,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
|
|||
stats.Counter.Follow, _ = x.Count(new(Follow))
|
||||
stats.Counter.Mirror, _ = x.Count(new(Mirror))
|
||||
stats.Counter.Release, _ = x.Count(new(Release))
|
||||
stats.Counter.LoginSource = LoginSources.Count(ctx)
|
||||
stats.Counter.LoginSource = Handle.LoginSources().Count(ctx)
|
||||
stats.Counter.Webhook, _ = x.Count(new(Webhook))
|
||||
stats.Counter.Milestone, _ = x.Count(new(Milestone))
|
||||
stats.Counter.Label, _ = x.Count(new(Label))
|
||||
|
|
|
@ -185,7 +185,7 @@ func (s *usersStore) Authenticate(ctx context.Context, login, password string, l
|
|||
|
||||
user := new(User)
|
||||
err := query.First(user).Error
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.Wrap(err, "get user")
|
||||
}
|
||||
|
||||
|
@ -221,7 +221,7 @@ func (s *usersStore) Authenticate(ctx context.Context, login, password string, l
|
|||
createNewUser = true
|
||||
}
|
||||
|
||||
source, err := LoginSources.GetByID(ctx, authSourceID)
|
||||
source, err := newLoginSourcesStore(s.DB, loadedLoginSourceFilesStore).GetByID(ctx, authSourceID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get login source")
|
||||
}
|
||||
|
|
|
@ -175,17 +175,19 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
|
|||
})
|
||||
|
||||
t.Run("via login source", func(t *testing.T) {
|
||||
mockLoginSources := NewMockLoginSourcesStore()
|
||||
mockLoginSources.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int64) (*LoginSource, error) {
|
||||
mockProvider := NewMockProvider()
|
||||
mockProvider.AuthenticateFunc.SetDefaultReturn(&auth.ExternalAccount{}, nil)
|
||||
s := &LoginSource{
|
||||
IsActived: true,
|
||||
Provider: mockProvider,
|
||||
}
|
||||
return s, nil
|
||||
})
|
||||
setMockLoginSourcesStore(t, mockLoginSources)
|
||||
loginSourcesStore := newLoginSourcesStore(db.DB, NewMockLoginSourceFilesStore())
|
||||
loginSource, err := loginSourcesStore.Create(
|
||||
ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.Mock,
|
||||
Name: "mock-1",
|
||||
Activated: true,
|
||||
Config: mockProviderConfig{
|
||||
ExternalAccount: &auth.ExternalAccount{},
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
bob, err := db.Create(ctx, "bob", "bob@example.com",
|
||||
CreateUserOptions{
|
||||
|
@ -195,31 +197,30 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
|
|||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := db.Authenticate(ctx, bob.Email, password, 1)
|
||||
user, err := db.Authenticate(ctx, bob.Email, password, loginSource.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, bob.Name, user.Name)
|
||||
})
|
||||
|
||||
t.Run("new user via login source", func(t *testing.T) {
|
||||
mockLoginSources := NewMockLoginSourcesStore()
|
||||
mockLoginSources.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int64) (*LoginSource, error) {
|
||||
mockProvider := NewMockProvider()
|
||||
mockProvider.AuthenticateFunc.SetDefaultReturn(
|
||||
&auth.ExternalAccount{
|
||||
Name: "cindy",
|
||||
Email: "cindy@example.com",
|
||||
loginSourcesStore := newLoginSourcesStore(db.DB, NewMockLoginSourceFilesStore())
|
||||
loginSource, err := loginSourcesStore.Create(
|
||||
ctx,
|
||||
CreateLoginSourceOptions{
|
||||
Type: auth.Mock,
|
||||
Name: "mock-2",
|
||||
Activated: true,
|
||||
Config: mockProviderConfig{
|
||||
ExternalAccount: &auth.ExternalAccount{
|
||||
Name: "cindy",
|
||||
Email: "cindy@example.com",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
s := &LoginSource{
|
||||
IsActived: true,
|
||||
Provider: mockProvider,
|
||||
}
|
||||
return s, nil
|
||||
})
|
||||
setMockLoginSourcesStore(t, mockLoginSources)
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := db.Authenticate(ctx, "cindy", password, 1)
|
||||
user, err := db.Authenticate(ctx, "cindy", password, loginSource.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "cindy", user.Name)
|
||||
|
||||
|
|
|
@ -35,13 +35,13 @@ func Authentications(c *context.Context) {
|
|||
c.PageIs("AdminAuthentications")
|
||||
|
||||
var err error
|
||||
c.Data["Sources"], err = database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
c.Data["Sources"], err = database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
}
|
||||
|
||||
c.Data["Total"] = database.LoginSources.Count(c.Req.Context())
|
||||
c.Data["Total"] = database.Handle.LoginSources().Count(c.Req.Context())
|
||||
c.Success(AUTHS)
|
||||
}
|
||||
|
||||
|
@ -159,7 +159,7 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
return
|
||||
}
|
||||
|
||||
source, err := database.LoginSources.Create(c.Req.Context(),
|
||||
source, err := database.Handle.LoginSources().Create(c.Req.Context(),
|
||||
database.CreateLoginSourceOptions{
|
||||
Type: auth.Type(f.Type),
|
||||
Name: f.Name,
|
||||
|
@ -179,7 +179,7 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
}
|
||||
|
||||
if source.IsDefault {
|
||||
err = database.LoginSources.ResetNonDefault(c.Req.Context(), source)
|
||||
err = database.Handle.LoginSources().ResetNonDefault(c.Req.Context(), source)
|
||||
if err != nil {
|
||||
c.Error(err, "reset non-default login sources")
|
||||
return
|
||||
|
@ -200,7 +200,7 @@ func EditAuthSource(c *context.Context) {
|
|||
c.Data["SecurityProtocols"] = securityProtocols
|
||||
c.Data["SMTPAuths"] = smtp.AuthTypes
|
||||
|
||||
source, err := database.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
|
||||
source, err := database.Handle.LoginSources().GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
|
||||
if err != nil {
|
||||
c.Error(err, "get login source by ID")
|
||||
return
|
||||
|
@ -218,7 +218,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
|
||||
c.Data["SMTPAuths"] = smtp.AuthTypes
|
||||
|
||||
source, err := database.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
|
||||
source, err := database.Handle.LoginSources().GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
|
||||
if err != nil {
|
||||
c.Error(err, "get login source by ID")
|
||||
return
|
||||
|
@ -257,13 +257,13 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
source.IsActived = f.IsActive
|
||||
source.IsDefault = f.IsDefault
|
||||
source.Provider = provider
|
||||
if err := database.LoginSources.Save(c.Req.Context(), source); err != nil {
|
||||
if err := database.Handle.LoginSources().Save(c.Req.Context(), source); err != nil {
|
||||
c.Error(err, "update login source")
|
||||
return
|
||||
}
|
||||
|
||||
if source.IsDefault {
|
||||
err = database.LoginSources.ResetNonDefault(c.Req.Context(), source)
|
||||
err = database.Handle.LoginSources().ResetNonDefault(c.Req.Context(), source)
|
||||
if err != nil {
|
||||
c.Error(err, "reset non-default login sources")
|
||||
return
|
||||
|
@ -278,7 +278,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
|||
|
||||
func DeleteAuthSource(c *context.Context) {
|
||||
id := c.ParamsInt64(":authid")
|
||||
if err := database.LoginSources.DeleteByID(c.Req.Context(), id); err != nil {
|
||||
if err := database.Handle.LoginSources().DeleteByID(c.Req.Context(), id); err != nil {
|
||||
if database.IsErrLoginSourceInUse(err) {
|
||||
c.Flash.Error(c.Tr("admin.auths.still_in_used"))
|
||||
} else {
|
||||
|
|
|
@ -46,7 +46,7 @@ func NewUser(c *context.Context) {
|
|||
|
||||
c.Data["login_type"] = "0-0"
|
||||
|
||||
sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
|
@ -62,7 +62,7 @@ func NewUserPost(c *context.Context, f form.AdminCrateUser) {
|
|||
c.Data["PageIsAdmin"] = true
|
||||
c.Data["PageIsAdminUsers"] = true
|
||||
|
||||
sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
|
@ -125,7 +125,7 @@ func prepareUserInfo(c *context.Context) *database.User {
|
|||
c.Data["User"] = u
|
||||
|
||||
if u.LoginSource > 0 {
|
||||
c.Data["LoginSource"], err = database.LoginSources.GetByID(c.Req.Context(), u.LoginSource)
|
||||
c.Data["LoginSource"], err = database.Handle.LoginSources().GetByID(c.Req.Context(), u.LoginSource)
|
||||
if err != nil {
|
||||
c.Error(err, "get login source by ID")
|
||||
return nil
|
||||
|
@ -134,7 +134,7 @@ func prepareUserInfo(c *context.Context) *database.User {
|
|||
c.Data["LoginSource"] = &database.LoginSource{}
|
||||
}
|
||||
|
||||
sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return nil
|
||||
|
|
|
@ -22,7 +22,7 @@ func parseLoginSource(c *context.APIContext, sourceID int64) {
|
|||
return
|
||||
}
|
||||
|
||||
_, err := database.LoginSources.GetByID(c.Req.Context(), sourceID)
|
||||
_, err := database.Handle.LoginSources().GetByID(c.Req.Context(), sourceID)
|
||||
if err != nil {
|
||||
if database.IsErrLoginSourceNotExist(err) {
|
||||
c.ErrorStatus(http.StatusUnprocessableEntity, err)
|
||||
|
|
|
@ -106,7 +106,7 @@ func Login(c *context.Context) {
|
|||
}
|
||||
|
||||
// Display normal login page
|
||||
loginSources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
|
||||
loginSources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
|
||||
if err != nil {
|
||||
c.Error(err, "list activated login sources")
|
||||
return
|
||||
|
@ -153,7 +153,7 @@ func afterLogin(c *context.Context, u *database.User, remember bool) {
|
|||
func LoginPost(c *context.Context, f form.SignIn) {
|
||||
c.Title("sign_in")
|
||||
|
||||
loginSources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
|
||||
loginSources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
|
||||
if err != nil {
|
||||
c.Error(err, "list activated login sources")
|
||||
return
|
||||
|
|
|
@ -25,10 +25,10 @@ mocks:
|
|||
sources:
|
||||
- path: gogs.io/gogs/internal/database
|
||||
interfaces:
|
||||
- LoginSourcesStore
|
||||
- LoginSourceFilesStore
|
||||
- LoginSourceFileStore
|
||||
- loginSourceFileStore
|
||||
- loginSourceFilesStore
|
||||
- filename: internal/database/mocks_gen.go
|
||||
sources:
|
||||
- path: gogs.io/gogs/internal/auth
|
||||
interfaces:
|
||||
- Provider
|
||||
|
|
Loading…
Reference in New Issue