internal/database: consistently use Store and s as receiver (#7669)

This commit is contained in:
Joe Chen 2024-02-19 20:00:13 -05:00 committed by GitHub
parent dfe27ad556
commit 917c14f2ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 371 additions and 371 deletions

View File

@ -74,9 +74,9 @@ func (t *AccessToken) AfterFind(tx *gorm.DB) error {
return nil
}
var _ AccessTokensStore = (*accessTokens)(nil)
var _ AccessTokensStore = (*accessTokensStore)(nil)
type accessTokens struct {
type accessTokensStore struct {
*gorm.DB
}
@ -93,8 +93,8 @@ func (err ErrAccessTokenAlreadyExist) Error() string {
return fmt.Sprintf("access token already exists: %v", err.args)
}
func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) {
err := db.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
func (s *accessTokensStore) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) {
err := s.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
if err == nil {
return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}}
} else if err != gorm.ErrRecordNotFound {
@ -110,7 +110,7 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (
Sha1: sha256[:40], // To pass the column unique constraint, keep the length of SHA1.
SHA256: sha256,
}
if err = db.WithContext(ctx).Create(accessToken).Error; err != nil {
if err = s.WithContext(ctx).Create(accessToken).Error; err != nil {
return nil, err
}
@ -119,8 +119,8 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (
return accessToken, nil
}
func (db *accessTokens) DeleteByID(ctx context.Context, userID, id int64) error {
return db.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error
func (s *accessTokensStore) DeleteByID(ctx context.Context, userID, id int64) error {
return s.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error
}
var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil)
@ -144,7 +144,7 @@ func (ErrAccessTokenNotExist) NotFound() bool {
return true
}
func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) {
func (s *accessTokensStore) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) {
// No need to waste a query for an empty SHA1.
if sha1 == "" {
return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}}
@ -152,7 +152,7 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke
sha256 := cryptoutil.SHA256(sha1)
token := new(AccessToken)
err := db.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error
err := s.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}}
@ -162,15 +162,15 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke
return token, nil
}
func (db *accessTokens) List(ctx context.Context, userID int64) ([]*AccessToken, error) {
func (s *accessTokensStore) List(ctx context.Context, userID int64) ([]*AccessToken, error) {
var tokens []*AccessToken
return tokens, db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error
return tokens, s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error
}
func (db *accessTokens) Touch(ctx context.Context, id int64) error {
return db.WithContext(ctx).
func (s *accessTokensStore) Touch(ctx context.Context, id int64) error {
return s.WithContext(ctx).
Model(new(AccessToken)).
Where("id = ?", id).
UpdateColumn("updated_unix", db.NowFunc().Unix()).
UpdateColumn("updated_unix", s.NowFunc().Unix()).
Error
}

View File

@ -98,13 +98,13 @@ func TestAccessTokens(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &accessTokens{
DB: newTestDB(t, "accessTokens"),
db := &accessTokensStore{
DB: newTestDB(t, "accessTokensStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *accessTokens)
test func(t *testing.T, ctx context.Context, db *accessTokensStore)
}{
{"Create", accessTokensCreate},
{"DeleteByID", accessTokensDeleteByID},
@ -125,7 +125,7 @@ func TestAccessTokens(t *testing.T) {
}
}
func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) {
func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create first access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)
@ -150,7 +150,7 @@ func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) {
assert.Equal(t, wantErr, err)
}
func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) {
func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)
@ -177,7 +177,7 @@ func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens)
assert.Equal(t, wantErr, err)
}
func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) {
func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)
@ -196,7 +196,7 @@ func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) {
assert.Equal(t, wantErr, err)
}
func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) {
func accessTokensList(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create two access tokens for user 1
_, err := db.Create(ctx, 1, "user1_1")
require.NoError(t, err)
@ -219,7 +219,7 @@ func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) {
assert.Equal(t, "user1_2", tokens[1].Name)
}
func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokens) {
func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)

View File

@ -70,19 +70,19 @@ type ActionsStore interface {
var Actions ActionsStore
var _ ActionsStore = (*actions)(nil)
var _ ActionsStore = (*actionsStore)(nil)
type actions struct {
type actionsStore struct {
*gorm.DB
}
// NewActionsStore returns a persistent interface for actions with given
// database connection.
func NewActionsStore(db *gorm.DB) ActionsStore {
return &actions{DB: db}
return &actionsStore{DB: db}
}
func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB {
func (s *actionsStore) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB {
/*
Equivalent SQL for PostgreSQL:
@ -102,18 +102,18 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after
ORDER BY id DESC
LIMIT @limit
*/
return db.WithContext(ctx).
return s.WithContext(ctx).
Where("user_id = ?", orgID).
Where(db.
Where(s.
// Not apply when afterID is not given
Where("?", afterID <= 0).
Or("id < ?", afterID),
).
Where("repo_id IN (?)", db.
Where("repo_id IN (?)", s.
Select("repository.id").
Table("repository").
Joins("JOIN team_repo ON repository.id = team_repo.repo_id").
Where("team_repo.team_id IN (?)", db.
Where("team_repo.team_id IN (?)", s.
Select("team_id").
Table("team_user").
Where("team_user.org_id = ? AND uid = ?", orgID, actorID),
@ -124,12 +124,12 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after
Order("id DESC")
}
func (db *actions) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) {
func (s *actionsStore) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) {
actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum)
return actions, db.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error
return actions, s.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error
}
func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB {
func (s *actionsStore) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB {
/*
Equivalent SQL for PostgreSQL:
@ -141,14 +141,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6
ORDER BY id DESC
LIMIT @limit
*/
return db.WithContext(ctx).
return s.WithContext(ctx).
Where("user_id = ?", userID).
Where(db.
Where(s.
// Not apply when afterID is not given
Where("?", afterID <= 0).
Or("id < ?", afterID),
).
Where(db.
Where(s.
// Not apply when in not profile page or the user is viewing own profile
Where("?", !isProfile || actorID == userID).
Or("is_private = ? AND act_user_id = ?", false, userID),
@ -157,14 +157,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6
Order("id DESC")
}
func (db *actions) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) {
func (s *actionsStore) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) {
actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum)
return actions, db.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error
return actions, s.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error
}
// notifyWatchers creates rows in action table for watchers who are able to see the action.
func (db *actions) notifyWatchers(ctx context.Context, act *Action) error {
watches, err := NewReposStore(db.DB).ListWatches(ctx, act.RepoID)
func (s *actionsStore) notifyWatchers(ctx context.Context, act *Action) error {
watches, err := NewReposStore(s.DB).ListWatches(ctx, act.RepoID)
if err != nil {
return errors.Wrap(err, "list watches")
}
@ -187,16 +187,16 @@ func (db *actions) notifyWatchers(ctx context.Context, act *Action) error {
actions = append(actions, clone(watch.UserID))
}
return db.Create(actions).Error
return s.Create(actions).Error
}
func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error {
func (s *actionsStore) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error {
opType := ActionCreateRepo
if repo.IsFork {
opType = ActionForkRepo
}
return db.notifyWatchers(ctx,
return s.notifyWatchers(ctx,
&Action{
ActUserID: doer.ID,
ActUserName: doer.Name,
@ -209,8 +209,8 @@ func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Reposit
)
}
func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error {
return db.notifyWatchers(ctx,
func (s *actionsStore) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error {
return s.notifyWatchers(ctx,
&Action{
ActUserID: doer.ID,
ActUserName: doer.Name,
@ -224,8 +224,8 @@ func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoNam
)
}
func (db *actions) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error {
return db.notifyWatchers(ctx,
func (s *actionsStore) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error {
return s.notifyWatchers(ctx,
&Action{
ActUserID: owner.ID,
ActUserName: owner.Name,
@ -249,13 +249,13 @@ type MirrorSyncPushOptions struct {
Commits *PushCommits
}
func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error {
func (s *actionsStore) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error {
if conf.UI.FeedMaxCommitNum > 0 && len(opts.Commits.Commits) > conf.UI.FeedMaxCommitNum {
opts.Commits.Commits = opts.Commits.Commits[:conf.UI.FeedMaxCommitNum]
}
apiCommits, err := opts.Commits.APIFormat(ctx,
NewUsersStore(db.DB),
NewUsersStore(s.DB),
repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name),
repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name),
)
@ -288,19 +288,19 @@ func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOption
return errors.Wrap(err, "marshal JSON")
}
return db.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data)
return s.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data)
}
func (db *actions) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error {
return db.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil)
func (s *actionsStore) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error {
return s.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil)
}
func (db *actions) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error {
return db.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil)
func (s *actionsStore) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error {
return s.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil)
}
func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error {
return db.notifyWatchers(ctx,
func (s *actionsStore) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error {
return s.notifyWatchers(ctx,
&Action{
ActUserID: doer.ID,
ActUserName: doer.Name,
@ -314,8 +314,8 @@ func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo
)
}
func (db *actions) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error {
return db.notifyWatchers(ctx,
func (s *actionsStore) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error {
return s.notifyWatchers(ctx,
&Action{
ActUserID: doer.ID,
ActUserName: doer.Name,
@ -487,13 +487,13 @@ type CommitRepoOptions struct {
Commits *PushCommits
}
func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error {
err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID)
func (s *actionsStore) CommitRepo(ctx context.Context, opts CommitRepoOptions) error {
err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID)
if err != nil {
return errors.Wrap(err, "touch repository")
}
pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName)
pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName)
if err != nil {
return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName)
}
@ -536,7 +536,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
}
action.OpType = ActionDeleteBranch
err = db.notifyWatchers(ctx, action)
err = s.notifyWatchers(ctx, action)
if err != nil {
return errors.Wrap(err, "notify watchers")
}
@ -580,7 +580,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
}
action.OpType = ActionCreateBranch
err = db.notifyWatchers(ctx, action)
err = s.notifyWatchers(ctx, action)
if err != nil {
return errors.Wrap(err, "notify watchers")
}
@ -589,7 +589,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
}
commits, err := opts.Commits.APIFormat(ctx,
NewUsersStore(db.DB),
NewUsersStore(s.DB),
repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name),
repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name),
)
@ -616,7 +616,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
}
action.OpType = ActionCommitRepo
err = db.notifyWatchers(ctx, action)
err = s.notifyWatchers(ctx, action)
if err != nil {
return errors.Wrap(err, "notify watchers")
}
@ -631,13 +631,13 @@ type PushTagOptions struct {
NewCommitID string
}
func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID)
func (s *actionsStore) PushTag(ctx context.Context, opts PushTagOptions) error {
err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID)
if err != nil {
return errors.Wrap(err, "touch repository")
}
pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName)
pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName)
if err != nil {
return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName)
}
@ -672,7 +672,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
}
action.OpType = ActionDeleteTag
err = db.notifyWatchers(ctx, action)
err = s.notifyWatchers(ctx, action)
if err != nil {
return errors.Wrap(err, "notify watchers")
}
@ -696,7 +696,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
}
action.OpType = ActionPushTag
err = db.notifyWatchers(ctx, action)
err = s.notifyWatchers(ctx, action)
if err != nil {
return errors.Wrap(err, "notify watchers")
}

View File

@ -99,13 +99,13 @@ func TestActions(t *testing.T) {
ctx := context.Background()
t.Parallel()
db := &actions{
DB: newTestDB(t, "actions"),
db := &actionsStore{
DB: newTestDB(t, "actionsStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *actions)
test func(t *testing.T, ctx context.Context, db *actionsStore)
}{
{"CommitRepo", actionsCommitRepo},
{"ListByOrganization", actionsListByOrganization},
@ -132,7 +132,7 @@ func TestActions(t *testing.T) {
}
}
func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) {
func actionsCommitRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -324,7 +324,7 @@ func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) {
})
}
func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) {
func actionsListByOrganization(t *testing.T, ctx context.Context, db *actionsStore) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL")
return
@ -363,14 +363,14 @@ func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return NewActionsStore(tx).(*actions).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action))
return NewActionsStore(tx).(*actionsStore).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action))
})
assert.Equal(t, test.want, got)
})
}
}
func actionsListByUser(t *testing.T, ctx context.Context, db *actions) {
func actionsListByUser(t *testing.T, ctx context.Context, db *actionsStore) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL")
return
@ -428,14 +428,14 @@ func actionsListByUser(t *testing.T, ctx context.Context, db *actions) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return NewActionsStore(tx).(*actions).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action))
return NewActionsStore(tx).(*actionsStore).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action))
})
assert.Equal(t, test.want, got)
})
}
}
func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) {
func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -480,7 +480,7 @@ func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got)
}
func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) {
func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -521,7 +521,7 @@ func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got)
}
func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) {
func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -562,7 +562,7 @@ func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got)
}
func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) {
func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -627,7 +627,7 @@ func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got)
}
func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) {
func actionsNewRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -702,7 +702,7 @@ func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) {
})
}
func actionsPushTag(t *testing.T, ctx context.Context, db *actions) {
func actionsPushTag(t *testing.T, ctx context.Context, db *actionsStore) {
// NOTE: We set a noop mock here to avoid data race with other tests that writes
// to the mock server because this function holds a lock.
conf.SetMockServer(t, conf.ServerOpts{})
@ -798,7 +798,7 @@ func actionsPushTag(t *testing.T, ctx context.Context, db *actions) {
})
}
func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) {
func actionsRenameRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -835,7 +835,7 @@ func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got)
}
func actionsTransferRepo(t *testing.T, ctx context.Context, db *actions) {
func actionsTransferRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{})

View File

@ -123,15 +123,15 @@ func Init(w logger.Writer) (*gorm.DB, error) {
}
// Initialize stores, sorted in alphabetical order.
AccessTokens = &accessTokens{DB: db}
AccessTokens = &accessTokensStore{DB: db}
Actions = NewActionsStore(db)
LoginSources = &loginSources{DB: db, files: sourceFiles}
LFS = &lfs{DB: db}
LoginSources = &loginSourcesStore{DB: db, files: sourceFiles}
LFS = &lfsStore{DB: db}
Notices = NewNoticesStore(db)
Orgs = NewOrgsStore(db)
Perms = NewPermsStore(db)
Repos = NewReposStore(db)
TwoFactors = &twoFactors{DB: db}
TwoFactors = &twoFactorsStore{DB: db}
Users = NewUsersStore(db)
return db, nil

View File

@ -38,20 +38,20 @@ type LFSObject struct {
CreatedAt time.Time `gorm:"not null"`
}
var _ LFSStore = (*lfs)(nil)
var _ LFSStore = (*lfsStore)(nil)
type lfs struct {
type lfsStore struct {
*gorm.DB
}
func (db *lfs) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
func (s *lfsStore) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
object := &LFSObject{
RepoID: repoID,
OID: oid,
Size: size,
Storage: storage,
}
return db.WithContext(ctx).Create(object).Error
return s.WithContext(ctx).Create(object).Error
}
type ErrLFSObjectNotExist struct {
@ -71,9 +71,9 @@ func (ErrLFSObjectNotExist) NotFound() bool {
return true
}
func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) {
func (s *lfsStore) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) {
object := new(LFSObject)
err := db.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
err := s.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}}
@ -83,13 +83,13 @@ func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID
return object, err
}
func (db *lfs) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
func (s *lfsStore) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
if len(oids) == 0 {
return []*LFSObject{}, nil
}
objects := make([]*LFSObject, 0, len(oids))
err := db.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error
err := s.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error
if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
}

View File

@ -23,13 +23,13 @@ func TestLFS(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &lfs{
DB: newTestDB(t, "lfs"),
db := &lfsStore{
DB: newTestDB(t, "lfsStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *lfs)
test func(t *testing.T, ctx context.Context, db *lfsStore)
}{
{"CreateObject", lfsCreateObject},
{"GetObjectByOID", lfsGetObjectByOID},
@ -48,7 +48,7 @@ func TestLFS(t *testing.T) {
}
}
func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) {
func lfsCreateObject(t *testing.T, ctx context.Context, db *lfsStore) {
// Create first LFS object
repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -65,7 +65,7 @@ func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) {
assert.Error(t, err)
}
func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) {
func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfsStore) {
// Create a LFS object
repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -82,7 +82,7 @@ func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) {
assert.Equal(t, expErr, err)
}
func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfs) {
func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfsStore) {
// Create two LFS objects
repoID := int64(1)
oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")

View File

@ -180,9 +180,9 @@ func (s *LoginSource) GitHub() *github.Config {
return s.Provider.Config().(*github.Config)
}
var _ LoginSourcesStore = (*loginSources)(nil)
var _ LoginSourcesStore = (*loginSourcesStore)(nil)
type loginSources struct {
type loginSourcesStore struct {
*gorm.DB
files loginSourceFilesStore
}
@ -208,8 +208,8 @@ func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists: %v", err.args)
}
func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
err := db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
} else if err != gorm.ErrRecordNotFound {
@ -226,13 +226,13 @@ func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOption
if err != nil {
return nil, err
}
return source, db.WithContext(ctx).Create(source).Error
return source, s.WithContext(ctx).Create(source).Error
}
func (db *loginSources) Count(ctx context.Context) int64 {
func (s *loginSourcesStore) Count(ctx context.Context) int64 {
var count int64
db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
return count + int64(db.files.Len())
s.WithContext(ctx).Model(new(LoginSource)).Count(&count)
return count + int64(s.files.Len())
}
type ErrLoginSourceInUse struct {
@ -248,24 +248,24 @@ func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users: %v", err.args)
}
func (db *loginSources) DeleteByID(ctx context.Context, id int64) error {
func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
var count int64
err := db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
}
return db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
}
func (db *loginSources) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
source := new(LoginSource)
err := db.WithContext(ctx).Where("id = ?", id).First(source).Error
err := s.WithContext(ctx).Where("id = ?", id).First(source).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return db.files.GetByID(id)
return s.files.GetByID(id)
}
return nil, err
}
@ -277,9 +277,9 @@ type ListLoginSourceOptions struct {
OnlyActivated bool
}
func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
var sources []*LoginSource
query := db.WithContext(ctx).Order("id ASC")
query := s.WithContext(ctx).Order("id ASC")
if opts.OnlyActivated {
query = query.Where("is_actived = ?", true)
}
@ -288,11 +288,11 @@ func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) (
return nil, err
}
return append(sources, db.files.List(opts)...), nil
return append(sources, s.files.List(opts)...), nil
}
func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := db.WithContext(ctx).
func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := s.WithContext(ctx).
Model(new(LoginSource)).
Where("id != ?", dflt.ID).
Updates(map[string]any{"is_default": false}).
@ -301,7 +301,7 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource)
return err
}
for _, source := range db.files.List(ListLoginSourceOptions{}) {
for _, source := range s.files.List(ListLoginSourceOptions{}) {
if source.File != nil && source.ID != dflt.ID {
source.File.SetGeneral("is_default", "false")
if err = source.File.Save(); err != nil {
@ -310,13 +310,13 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource)
}
}
db.files.Update(dflt)
s.files.Update(dflt)
return nil
}
func (db *loginSources) Save(ctx context.Context, source *LoginSource) error {
func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
if source.File == nil {
return db.WithContext(ctx).Save(source).Error
return s.WithContext(ctx).Save(source).Error
}
source.File.SetGeneral("name", source.Name)

View File

@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &loginSources{
DB: newTestDB(t, "loginSources"),
db := &loginSourcesStore{
DB: newTestDB(t, "loginSourcesStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *loginSources)
test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
}{
{"Create", loginSourcesCreate},
{"Count", loginSourcesCount},
@ -192,7 +192,7 @@ func TestLoginSources(t *testing.T) {
}
}
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
// Create first login source with name "GitHub"
source, err := db.Create(ctx,
CreateLoginSourceOptions{
@ -219,7 +219,7 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
assert.Equal(t, wantErr, err)
}
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
// Create two login sources, one in database and one as source file.
_, err := db.Create(ctx,
CreateLoginSourceOptions{
@ -241,7 +241,7 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
assert.Equal(t, int64(3), db.Count(ctx))
}
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) {
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
t.Run("delete but in used", func(t *testing.T) {
source, err := db.Create(ctx,
CreateLoginSourceOptions{
@ -257,7 +257,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources)
require.NoError(t, err)
// Create a user that uses this login source
_, err = (&users{DB: db.DB}).Create(ctx, "alice", "",
_, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
CreateUserOptions{
LoginSource: source.ID,
},
@ -308,7 +308,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources)
assert.Equal(t, wantErr, err)
}
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
mock := NewMockLoginSourceFilesStore()
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
if id != 101 {
@ -344,7 +344,7 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
require.NoError(t, err)
}
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
if opts.OnlyActivated {
@ -393,7 +393,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
assert.Equal(t, 2, len(sources), "number of sources")
}
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) {
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
mockFile := NewMockLoginSourceFileStore()
@ -448,7 +448,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
assert.False(t, source2.IsDefault)
}
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) {
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
t.Run("save to database", func(t *testing.T) {
// Create a login source with name "GitHub"
source, err := db.Create(ctx,

View File

@ -32,7 +32,7 @@ func setMockLoginSourcesStore(t *testing.T, mock LoginSourcesStore) {
})
}
func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) {
func setMockLoginSourceFilesStore(t *testing.T, db *loginSourcesStore, mock loginSourceFilesStore) {
before := db.files
db.files = mock
t.Cleanup(func() {

View File

@ -32,20 +32,20 @@ type NoticesStore interface {
var Notices NoticesStore
var _ NoticesStore = (*notices)(nil)
var _ NoticesStore = (*noticesStore)(nil)
type notices struct {
type noticesStore struct {
*gorm.DB
}
// NewNoticesStore returns a persistent interface for system notices with given
// database connection.
func NewNoticesStore(db *gorm.DB) NoticesStore {
return &notices{DB: db}
return &noticesStore{DB: db}
}
func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) error {
return db.WithContext(ctx).Create(
func (s *noticesStore) Create(ctx context.Context, typ NoticeType, desc string) error {
return s.WithContext(ctx).Create(
&Notice{
Type: typ,
Description: desc,
@ -53,26 +53,26 @@ func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) erro
).Error
}
func (db *notices) DeleteByIDs(ctx context.Context, ids ...int64) error {
return db.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error
func (s *noticesStore) DeleteByIDs(ctx context.Context, ids ...int64) error {
return s.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error
}
func (db *notices) DeleteAll(ctx context.Context) error {
return db.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error
func (s *noticesStore) DeleteAll(ctx context.Context) error {
return s.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error
}
func (db *notices) List(ctx context.Context, page, pageSize int) ([]*Notice, error) {
func (s *noticesStore) List(ctx context.Context, page, pageSize int) ([]*Notice, error) {
notices := make([]*Notice, 0, pageSize)
return notices, db.WithContext(ctx).
return notices, s.WithContext(ctx).
Limit(pageSize).Offset((page - 1) * pageSize).
Order("id DESC").
Find(&notices).
Error
}
func (db *notices) Count(ctx context.Context) int64 {
func (s *noticesStore) Count(ctx context.Context) int64 {
var count int64
db.WithContext(ctx).Model(&Notice{}).Count(&count)
s.WithContext(ctx).Model(&Notice{}).Count(&count)
return count
}

View File

@ -65,13 +65,13 @@ func TestNotices(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &notices{
DB: newTestDB(t, "notices"),
db := &noticesStore{
DB: newTestDB(t, "noticesStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *notices)
test func(t *testing.T, ctx context.Context, db *noticesStore)
}{
{"Create", noticesCreate},
{"DeleteByIDs", noticesDeleteByIDs},
@ -92,7 +92,7 @@ func TestNotices(t *testing.T) {
}
}
func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
func noticesCreate(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err)
@ -100,7 +100,7 @@ func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
assert.Equal(t, int64(1), count)
}
func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err)
@ -120,7 +120,7 @@ func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
assert.Equal(t, int64(0), count)
}
func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
func noticesDeleteAll(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err)
@ -131,7 +131,7 @@ func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
assert.Equal(t, int64(0), count)
}
func noticesList(t *testing.T, ctx context.Context, db *notices) {
func noticesList(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test 1")
require.NoError(t, err)
err = db.Create(ctx, NoticeTypeRepository, "test 2")
@ -151,7 +151,7 @@ func noticesList(t *testing.T, ctx context.Context, db *notices) {
require.Len(t, got, 2)
}
func noticesCount(t *testing.T, ctx context.Context, db *notices) {
func noticesCount(t *testing.T, ctx context.Context, db *noticesStore) {
count := db.Count(ctx)
assert.Equal(t, int64(0), count)

View File

@ -30,16 +30,16 @@ type OrgsStore interface {
var Orgs OrgsStore
var _ OrgsStore = (*orgs)(nil)
var _ OrgsStore = (*orgsStore)(nil)
type orgs struct {
type orgsStore struct {
*gorm.DB
}
// NewOrgsStore returns a persistent interface for orgs with given database
// connection.
func NewOrgsStore(db *gorm.DB) OrgsStore {
return &orgs{DB: db}
return &orgsStore{DB: db}
}
type ListOrgsOptions struct {
@ -49,7 +49,7 @@ type ListOrgsOptions struct {
IncludePrivateMembers bool
}
func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) {
func (s *orgsStore) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) {
if opts.MemberID <= 0 {
return nil, errors.New("MemberID must be greater than 0")
}
@ -64,7 +64,7 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization
[AND org_user.is_public = @includePrivateMembers]
ORDER BY org.id ASC
*/
tx := db.WithContext(ctx).
tx := s.WithContext(ctx).
Joins(dbutil.Quote("JOIN org_user ON org_user.org_id = %s.id", "user")).
Where("org_user.uid = ?", opts.MemberID).
Order(dbutil.Quote("%s.id ASC", "user"))
@ -76,13 +76,13 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization
return orgs, tx.Find(&orgs).Error
}
func (db *orgs) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) {
return searchUserByName(ctx, db.DB, UserTypeOrganization, keyword, page, pageSize, orderBy)
func (s *orgsStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) {
return searchUserByName(ctx, s.DB, UserTypeOrganization, keyword, page, pageSize, orderBy)
}
func (db *orgs) CountByUser(ctx context.Context, userID int64) (int64, error) {
func (s *orgsStore) CountByUser(ctx context.Context, userID int64) (int64, error) {
var count int64
return count, db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
return count, s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
}
type Organization = User

View File

@ -21,13 +21,13 @@ func TestOrgs(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &orgs{
DB: newTestDB(t, "orgs"),
db := &orgsStore{
DB: newTestDB(t, "orgsStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *orgs)
test func(t *testing.T, ctx context.Context, db *orgsStore)
}{
{"List", orgsList},
{"SearchByName", orgsSearchByName},
@ -46,7 +46,7 @@ func TestOrgs(t *testing.T) {
}
}
func orgsList(t *testing.T, ctx context.Context, db *orgs) {
func orgsList(t *testing.T, ctx context.Context, db *orgsStore) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -116,7 +116,7 @@ func orgsList(t *testing.T, ctx context.Context, db *orgs) {
}
}
func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) {
func orgsSearchByName(t *testing.T, ctx context.Context, db *orgsStore) {
// TODO: Use Orgs.Create to replace SQL hack when the method is available.
usersStore := NewUsersStore(db.DB)
org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"})
@ -161,7 +161,7 @@ func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) {
})
}
func orgsCountByUser(t *testing.T, ctx context.Context, db *orgs) {
func orgsCountByUser(t *testing.T, ctx context.Context, db *orgsStore) {
// TODO: Use Orgs.Join to replace SQL hack when the method is available.
err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error
require.NoError(t, err)

View File

@ -74,16 +74,16 @@ func ParseAccessMode(permission string) AccessMode {
}
}
var _ PermsStore = (*perms)(nil)
var _ PermsStore = (*permsStore)(nil)
type perms struct {
type permsStore struct {
*gorm.DB
}
// NewPermsStore returns a persistent interface for permissions with given
// database connection.
func NewPermsStore(db *gorm.DB) PermsStore {
return &perms{DB: db}
return &permsStore{DB: db}
}
type AccessModeOptions struct {
@ -91,7 +91,7 @@ type AccessModeOptions struct {
Private bool // Whether the repository is private.
}
func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) {
func (s *permsStore) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) {
if repoID <= 0 {
return AccessModeNone
}
@ -111,7 +111,7 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce
}
access := new(Access)
err := db.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error
err := s.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error
if err != nil {
if err != gorm.ErrRecordNotFound {
log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repoID, err)
@ -121,11 +121,11 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce
return access.Mode
}
func (db *perms) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
return desired <= db.AccessMode(ctx, userID, repoID, opts)
func (s *permsStore) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
return desired <= s.AccessMode(ctx, userID, repoID, opts)
}
func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error {
func (s *permsStore) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error {
records := make([]*Access, 0, len(accessMap))
for userID, mode := range accessMap {
records = append(records, &Access{
@ -135,7 +135,7 @@ func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[i
})
}
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Where("repo_id = ?", repoID).Delete(new(Access)).Error
if err != nil {
return err

View File

@ -19,13 +19,13 @@ func TestPerms(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &perms{
DB: newTestDB(t, "perms"),
db := &permsStore{
DB: newTestDB(t, "permsStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *perms)
test func(t *testing.T, ctx context.Context, db *permsStore)
}{
{"AccessMode", permsAccessMode},
{"Authorize", permsAuthorize},
@ -44,7 +44,7 @@ func TestPerms(t *testing.T) {
}
}
func permsAccessMode(t *testing.T, ctx context.Context, db *perms) {
func permsAccessMode(t *testing.T, ctx context.Context, db *permsStore) {
// Set up permissions
err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{
@ -155,7 +155,7 @@ func permsAccessMode(t *testing.T, ctx context.Context, db *perms) {
}
}
func permsAuthorize(t *testing.T, ctx context.Context, db *perms) {
func permsAuthorize(t *testing.T, ctx context.Context, db *permsStore) {
// Set up permissions
err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{
@ -241,7 +241,7 @@ func permsAuthorize(t *testing.T, ctx context.Context, db *perms) {
}
}
func permsSetRepoPerms(t *testing.T, ctx context.Context, db *perms) {
func permsSetRepoPerms(t *testing.T, ctx context.Context, db *permsStore) {
for _, update := range []struct {
repoID int64
accessMap map[int64]AccessMode

View File

@ -24,23 +24,23 @@ type PublicKeysStore interface {
var PublicKeys PublicKeysStore
var _ PublicKeysStore = (*publicKeys)(nil)
var _ PublicKeysStore = (*publicKeysStore)(nil)
type publicKeys struct {
type publicKeysStore struct {
*gorm.DB
}
// NewPublicKeysStore returns a persistent interface for public keys with given
// database connection.
func NewPublicKeysStore(db *gorm.DB) PublicKeysStore {
return &publicKeys{DB: db}
return &publicKeysStore{DB: db}
}
func authorizedKeysPath() string {
return filepath.Join(conf.SSH.RootPath, "authorized_keys")
}
func (db *publicKeys) RewriteAuthorizedKeys() error {
func (s *publicKeysStore) RewriteAuthorizedKeys() error {
sshOpLocker.Lock()
defer sshOpLocker.Unlock()
@ -61,7 +61,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error {
// NOTE: More recently updated keys are more likely to be used more frequently,
// putting them in the earlier lines could speed up the key lookup by SSHD.
rows, err := db.Model(&PublicKey{}).Order("updated_unix DESC").Rows()
rows, err := s.Model(&PublicKey{}).Order("updated_unix DESC").Rows()
if err != nil {
return errors.Wrap(err, "iterate public keys")
}
@ -69,7 +69,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error {
for rows.Next() {
var key PublicKey
err = db.ScanRows(rows, &key)
err = s.ScanRows(rows, &key)
if err != nil {
return errors.Wrap(err, "scan rows")
}

View File

@ -24,13 +24,13 @@ func TestPublicKeys(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &publicKeys{
DB: newTestDB(t, "publicKeys"),
db := &publicKeysStore{
DB: newTestDB(t, "publicKeysStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *publicKeys)
test func(t *testing.T, ctx context.Context, db *publicKeysStore)
}{
{"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys},
} {
@ -47,7 +47,7 @@ func TestPublicKeys(t *testing.T) {
}
}
func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeys) {
func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeysStore) {
// TODO: Use PublicKeys.Add to replace SQL hack when the method is available.
publicKey := &PublicKey{
OwnerID: 1,

View File

@ -119,16 +119,16 @@ func (r *Repository) APIFormat(owner *User, opts ...RepositoryAPIFormatOptions)
}
}
var _ ReposStore = (*repos)(nil)
var _ ReposStore = (*reposStore)(nil)
type repos struct {
type reposStore struct {
*gorm.DB
}
// NewReposStore returns a persistent interface for repositories with given
// database connection.
func NewReposStore(db *gorm.DB) ReposStore {
return &repos{DB: db}
return &reposStore{DB: db}
}
type ErrRepoAlreadyExist struct {
@ -157,13 +157,13 @@ type CreateRepoOptions struct {
ForkID int64
}
func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) {
func (s *reposStore) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) {
err := isRepoNameAllowed(opts.Name)
if err != nil {
return nil, err
}
_, err = db.GetByName(ctx, ownerID, opts.Name)
_, err = s.GetByName(ctx, ownerID, opts.Name)
if err == nil {
return nil, ErrRepoAlreadyExist{
args: errutil.Args{
@ -189,7 +189,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio
IsFork: opts.Fork,
ForkID: opts.ForkID,
}
return repo, db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return repo, s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err = tx.Create(repo).Error
if err != nil {
return errors.Wrap(err, "create")
@ -203,7 +203,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio
})
}
func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) {
func (s *reposStore) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) {
/*
Equivalent SQL for PostgreSQL:
@ -214,7 +214,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64,
LIMIT @limit
*/
var repos []*Repository
return repos, db.WithContext(ctx).
return repos, s.WithContext(ctx).
Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID).
Where("access.mode >= ?", AccessModeRead).
Order(orderBy).
@ -223,7 +223,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64,
Error
}
func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) {
func (s *reposStore) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) {
/*
Equivalent SQL for PostgreSQL:
@ -238,7 +238,7 @@ func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collabor
*Repository
Mode AccessMode
}
err := db.WithContext(ctx).
err := s.WithContext(ctx).
Select("repository.*", "access.mode").
Table("repository").
Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID).
@ -275,9 +275,9 @@ func (ErrRepoNotExist) NotFound() bool {
return true
}
func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) {
func (s *reposStore) GetByID(ctx context.Context, id int64) (*Repository, error) {
repo := new(Repository)
err := db.WithContext(ctx).Where("id = ?", id).First(repo).Error
err := s.WithContext(ctx).Where("id = ?", id).First(repo).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrRepoNotExist{errutil.Args{"repoID": id}}
@ -287,9 +287,9 @@ func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) {
return repo, nil
}
func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) {
func (s *reposStore) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) {
repo := new(Repository)
err := db.WithContext(ctx).
err := s.WithContext(ctx).
Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).
First(repo).
Error
@ -307,7 +307,7 @@ func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Re
return repo, nil
}
func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error {
func (s *reposStore) recountStars(tx *gorm.DB, userID, repoID int64) error {
/*
Equivalent SQL for PostgreSQL:
@ -350,40 +350,40 @@ func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error {
return nil
}
func (db *repos) Star(ctx context.Context, userID, repoID int64) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
s := &Star{
func (s *reposStore) Star(ctx context.Context, userID, repoID int64) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
star := &Star{
UserID: userID,
RepoID: repoID,
}
result := tx.FirstOrCreate(s, s)
result := tx.FirstOrCreate(star, star)
if result.Error != nil {
return errors.Wrap(result.Error, "upsert")
} else if result.RowsAffected <= 0 {
return nil // Relation already exists
}
return db.recountStars(tx, userID, repoID)
return s.recountStars(tx, userID, repoID)
})
}
func (db *repos) Touch(ctx context.Context, id int64) error {
return db.WithContext(ctx).
func (s *reposStore) Touch(ctx context.Context, id int64) error {
return s.WithContext(ctx).
Model(new(Repository)).
Where("id = ?", id).
Updates(map[string]any{
"is_bare": false,
"updated_unix": db.NowFunc().Unix(),
"updated_unix": s.NowFunc().Unix(),
}).
Error
}
func (db *repos) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) {
func (s *reposStore) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) {
var watches []*Watch
return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error
return watches, s.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error
}
func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error {
func (s *reposStore) recountWatches(tx *gorm.DB, repoID int64) error {
/*
Equivalent SQL for PostgreSQL:
@ -402,8 +402,8 @@ func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error {
Error
}
func (db *repos) Watch(ctx context.Context, userID, repoID int64) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
func (s *reposStore) Watch(ctx context.Context, userID, repoID int64) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
w := &Watch{
UserID: userID,
RepoID: repoID,
@ -415,12 +415,12 @@ func (db *repos) Watch(ctx context.Context, userID, repoID int64) error {
return nil // Relation already exists
}
return db.recountWatches(tx, repoID)
return s.recountWatches(tx, repoID)
})
}
func (db *repos) HasForkedBy(ctx context.Context, repoID, userID int64) bool {
func (s *reposStore) HasForkedBy(ctx context.Context, repoID, userID int64) bool {
var count int64
db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count)
s.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count)
return count > 0
}

View File

@ -85,13 +85,13 @@ func TestRepos(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &repos{
db := &reposStore{
DB: newTestDB(t, "repos"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *repos)
test func(t *testing.T, ctx context.Context, db *reposStore)
}{
{"Create", reposCreate},
{"GetByCollaboratorID", reposGetByCollaboratorID},
@ -117,7 +117,7 @@ func TestRepos(t *testing.T) {
}
}
func reposCreate(t *testing.T, ctx context.Context, db *repos) {
func reposCreate(t *testing.T, ctx context.Context, db *reposStore) {
t.Run("name not allowed", func(t *testing.T) {
_, err := db.Create(ctx,
1,
@ -159,7 +159,7 @@ func reposCreate(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default.
}
func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) {
func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -185,7 +185,7 @@ func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) {
})
}
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *repos) {
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -213,7 +213,7 @@ func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, d
assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID])
}
func reposGetByID(t *testing.T, ctx context.Context, db *repos) {
func reposGetByID(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
@ -226,7 +226,7 @@ func reposGetByID(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, wantErr, err)
}
func reposGetByName(t *testing.T, ctx context.Context, db *repos) {
func reposGetByName(t *testing.T, ctx context.Context, db *reposStore) {
repo, err := db.Create(ctx, 1,
CreateRepoOptions{
Name: "repo1",
@ -242,7 +242,7 @@ func reposGetByName(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, wantErr, err)
}
func reposStar(t *testing.T, ctx context.Context, db *repos) {
func reposStar(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
usersStore := NewUsersStore(db.DB)
@ -261,7 +261,7 @@ func reposStar(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, 1, alice.NumStars)
}
func reposTouch(t *testing.T, ctx context.Context, db *repos) {
func reposTouch(t *testing.T, ctx context.Context, db *reposStore) {
repo, err := db.Create(ctx, 1,
CreateRepoOptions{
Name: "repo1",
@ -287,7 +287,7 @@ func reposTouch(t *testing.T, ctx context.Context, db *repos) {
assert.False(t, got.IsBare)
}
func reposListWatches(t *testing.T, ctx context.Context, db *repos) {
func reposListWatches(t *testing.T, ctx context.Context, db *reposStore) {
err := db.Watch(ctx, 1, 1)
require.NoError(t, err)
err = db.Watch(ctx, 2, 1)
@ -308,7 +308,7 @@ func reposListWatches(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, want, got)
}
func reposWatch(t *testing.T, ctx context.Context, db *repos) {
func reposWatch(t *testing.T, ctx context.Context, db *reposStore) {
reposStore := NewReposStore(db.DB)
repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
@ -325,7 +325,7 @@ func reposWatch(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default.
}
func reposHasForkedBy(t *testing.T, ctx context.Context, db *repos) {
func reposHasForkedBy(t *testing.T, ctx context.Context, db *reposStore) {
has := db.HasForkedBy(ctx, 1, 2)
assert.False(t, has)

View File

@ -50,13 +50,13 @@ func (t *TwoFactor) AfterFind(_ *gorm.DB) error {
return nil
}
var _ TwoFactorsStore = (*twoFactors)(nil)
var _ TwoFactorsStore = (*twoFactorsStore)(nil)
type twoFactors struct {
type twoFactorsStore struct {
*gorm.DB
}
func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret string) error {
func (s *twoFactorsStore) Create(ctx context.Context, userID int64, key, secret string) error {
encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret))
if err != nil {
return errors.Wrap(err, "encrypt secret")
@ -71,7 +71,7 @@ func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret stri
return errors.Wrap(err, "generate recovery codes")
}
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Create(tf).Error
if err != nil {
return err
@ -100,9 +100,9 @@ func (ErrTwoFactorNotFound) NotFound() bool {
return true
}
func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) {
func (s *twoFactorsStore) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) {
tf := new(TwoFactor)
err := db.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error
err := s.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}}
@ -112,9 +112,9 @@ func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor
return tf, nil
}
func (db *twoFactors) IsEnabled(ctx context.Context, userID int64) bool {
func (s *twoFactorsStore) IsEnabled(ctx context.Context, userID int64) bool {
var count int64
err := db.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
err := s.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
if err != nil {
log.Error("Failed to count two factors [user_id: %d]: %v", userID, err)
}

View File

@ -67,13 +67,13 @@ func TestTwoFactors(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &twoFactors{
DB: newTestDB(t, "twoFactors"),
db := &twoFactorsStore{
DB: newTestDB(t, "twoFactorsStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *twoFactors)
test func(t *testing.T, ctx context.Context, db *twoFactorsStore)
}{
{"Create", twoFactorsCreate},
{"GetByUserID", twoFactorsGetByUserID},
@ -92,7 +92,7 @@ func TestTwoFactors(t *testing.T) {
}
}
func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) {
func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactorsStore) {
// Create a 2FA token
err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err)
@ -109,7 +109,7 @@ func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) {
assert.Equal(t, int64(10), count)
}
func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) {
func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactorsStore) {
// Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err)
@ -124,7 +124,7 @@ func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) {
assert.Equal(t, wantErr, err)
}
func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactors) {
func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactorsStore) {
// Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err)

View File

@ -146,16 +146,16 @@ type UsersStore interface {
var Users UsersStore
var _ UsersStore = (*users)(nil)
var _ UsersStore = (*usersStore)(nil)
type users struct {
type usersStore struct {
*gorm.DB
}
// NewUsersStore returns a persistent interface for users with given database
// connection.
func NewUsersStore(db *gorm.DB) UsersStore {
return &users{DB: db}
return &usersStore{DB: db}
}
type ErrLoginSourceMismatch struct {
@ -173,10 +173,10 @@ func (err ErrLoginSourceMismatch) Error() string {
return fmt.Sprintf("login source mismatch: %v", err.args)
}
func (db *users) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) {
func (s *usersStore) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) {
login = strings.ToLower(login)
query := db.WithContext(ctx)
query := s.WithContext(ctx)
if strings.Contains(login, "@") {
query = query.Where("email = ?", login)
} else {
@ -244,7 +244,7 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login
return nil, fmt.Errorf("invalid pattern for attribute 'username' [%s]: must be valid alpha or numeric or dash(-_) or dot characters", extAccount.Name)
}
return db.Create(ctx, extAccount.Name, extAccount.Email,
return s.Create(ctx, extAccount.Name, extAccount.Email,
CreateUserOptions{
FullName: extAccount.FullName,
LoginSource: authSourceID,
@ -257,13 +257,13 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login
)
}
func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername string) error {
func (s *usersStore) ChangeUsername(ctx context.Context, userID int64, newUsername string) error {
err := isUsernameAllowed(newUsername)
if err != nil {
return err
}
if db.IsUsernameUsed(ctx, newUsername, userID) {
if s.IsUsernameUsed(ctx, newUsername, userID) {
return ErrUserAlreadyExist{
args: errutil.Args{
"name": newUsername,
@ -271,12 +271,12 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
}
}
user, err := db.GetByID(ctx, userID)
user, err := s.GetByID(ctx, userID)
if err != nil {
return errors.Wrap(err, "get user")
}
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Model(&User{}).
Where("id = ?", user.ID).
Updates(map[string]any{
@ -338,9 +338,9 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
})
}
func (db *users) Count(ctx context.Context) int64 {
func (s *usersStore) Count(ctx context.Context) int64 {
var count int64
db.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count)
s.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count)
return count
}
@ -393,13 +393,13 @@ func (err ErrEmailAlreadyUsed) Error() string {
return fmt.Sprintf("email has been used: %v", err.args)
}
func (db *users) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) {
func (s *usersStore) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) {
err := isUsernameAllowed(username)
if err != nil {
return nil, err
}
if db.IsUsernameUsed(ctx, username, 0) {
if s.IsUsernameUsed(ctx, username, 0) {
return nil, ErrUserAlreadyExist{
args: errutil.Args{
"name": username,
@ -408,7 +408,7 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
}
email = strings.ToLower(strings.TrimSpace(email))
_, err = db.GetByEmail(ctx, email)
_, err = s.GetByEmail(ctx, email)
if err == nil {
return nil, ErrEmailAlreadyUsed{
args: errutil.Args{
@ -446,17 +446,17 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
}
user.Password = userutil.EncodePassword(user.Password, user.Salt)
return user, db.WithContext(ctx).Create(user).Error
return user, s.WithContext(ctx).Create(user).Error
}
func (db *users) DeleteCustomAvatar(ctx context.Context, userID int64) error {
func (s *usersStore) DeleteCustomAvatar(ctx context.Context, userID int64) error {
_ = os.Remove(userutil.CustomAvatarPath(userID))
return db.WithContext(ctx).
return s.WithContext(ctx).
Model(&User{}).
Where("id = ?", userID).
Updates(map[string]any{
"use_custom_avatar": false,
"updated_unix": db.NowFunc().Unix(),
"updated_unix": s.NowFunc().Unix(),
}).
Error
}
@ -491,8 +491,8 @@ func (err ErrUserHasOrgs) Error() string {
return fmt.Sprintf("user still has organization membership: %v", err.args)
}
func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error {
user, err := db.GetByID(ctx, userID)
func (s *usersStore) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error {
user, err := s.GetByID(ctx, userID)
if err != nil {
if IsErrUserNotExist(err) {
return nil
@ -503,14 +503,14 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
// Double check the user is not a direct owner of any repository and not a
// member of any organization.
var count int64
err = db.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error
err = s.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error
if err != nil {
return errors.Wrap(err, "count repositories")
} else if count > 0 {
return ErrUserOwnRepos{args: errutil.Args{"userID": userID}}
}
err = db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
err = s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
if err != nil {
return errors.Wrap(err, "count organization membership")
} else if count > 0 {
@ -518,7 +518,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
}
needsRewriteAuthorizedKeys := false
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err = s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
/*
Equivalent SQL for PostgreSQL:
@ -645,7 +645,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
_ = os.Remove(userutil.CustomAvatarPath(userID))
if needsRewriteAuthorizedKeys {
err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys()
err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys()
if err != nil {
return errors.Wrap(err, `rewrite "authorized_keys" file`)
}
@ -655,15 +655,15 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
// NOTE: We do not take context.Context here because this operation in practice
// could much longer than the general request timeout (e.g. one minute).
func (db *users) DeleteInactivated() error {
func (s *usersStore) DeleteInactivated() error {
var userIDs []int64
err := db.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error
err := s.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error
if err != nil {
return errors.Wrap(err, "get inactivated user IDs")
}
for _, userID := range userIDs {
err = db.DeleteByID(context.Background(), userID, true)
err = s.DeleteByID(context.Background(), userID, true)
if err != nil {
// Skip users that may had set to inactivated by admins.
if IsErrUserOwnRepos(err) || IsErrUserHasOrgs(err) {
@ -672,14 +672,14 @@ func (db *users) DeleteInactivated() error {
return errors.Wrapf(err, "delete user with ID %d", userID)
}
}
err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys()
err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys()
if err != nil {
return errors.Wrap(err, `rewrite "authorized_keys" file`)
}
return nil
}
func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error {
func (*usersStore) recountFollows(tx *gorm.DB, userID, followID int64) error {
/*
Equivalent SQL for PostgreSQL:
@ -722,12 +722,12 @@ func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error {
return nil
}
func (db *users) Follow(ctx context.Context, userID, followID int64) error {
func (s *usersStore) Follow(ctx context.Context, userID, followID int64) error {
if userID == followID {
return nil
}
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
f := &Follow{
UserID: userID,
FollowID: followID,
@ -739,26 +739,26 @@ func (db *users) Follow(ctx context.Context, userID, followID int64) error {
return nil // Relation already exists
}
return db.recountFollows(tx, userID, followID)
return s.recountFollows(tx, userID, followID)
})
}
func (db *users) Unfollow(ctx context.Context, userID, followID int64) error {
func (s *usersStore) Unfollow(ctx context.Context, userID, followID int64) error {
if userID == followID {
return nil
}
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error
if err != nil {
return errors.Wrap(err, "delete")
}
return db.recountFollows(tx, userID, followID)
return s.recountFollows(tx, userID, followID)
})
}
func (db *users) IsFollowing(ctx context.Context, userID, followID int64) bool {
return db.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil
func (s *usersStore) IsFollowing(ctx context.Context, userID, followID int64) bool {
return s.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil
}
var _ errutil.NotFound = (*ErrUserNotExist)(nil)
@ -782,7 +782,7 @@ func (ErrUserNotExist) NotFound() bool {
return true
}
func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
func (s *usersStore) GetByEmail(ctx context.Context, email string) (*User, error) {
if email == "" {
return nil, ErrUserNotExist{args: errutil.Args{"email": email}}
}
@ -801,10 +801,10 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
)
*/
user := new(User)
err := db.WithContext(ctx).
err := s.WithContext(ctx).
Joins(dbutil.Quote("LEFT JOIN email_address ON email_address.uid = %s.id", "user"), true).
Where(dbutil.Quote("%s.type = ?", "user"), UserTypeIndividual).
Where(db.
Where(s.
Where(dbutil.Quote("%[1]s.email = ? AND %[1]s.is_active = ?", "user"), email, true).
Or("email_address.email = ? AND email_address.is_activated = ?", email, true),
).
@ -819,9 +819,9 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
return user, nil
}
func (db *users) GetByID(ctx context.Context, id int64) (*User, error) {
func (s *usersStore) GetByID(ctx context.Context, id int64) (*User, error) {
user := new(User)
err := db.WithContext(ctx).Where("id = ?", id).First(user).Error
err := s.WithContext(ctx).Where("id = ?", id).First(user).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"userID": id}}
@ -831,9 +831,9 @@ func (db *users) GetByID(ctx context.Context, id int64) (*User, error) {
return user, nil
}
func (db *users) GetByUsername(ctx context.Context, username string) (*User, error) {
func (s *usersStore) GetByUsername(ctx context.Context, username string) (*User, error) {
user := new(User)
err := db.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error
err := s.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"name": username}}
@ -843,9 +843,9 @@ func (db *users) GetByUsername(ctx context.Context, username string) (*User, err
return user, nil
}
func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
func (s *usersStore) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
user := new(User)
err := db.WithContext(ctx).
err := s.WithContext(ctx).
Joins(dbutil.Quote("JOIN public_key ON public_key.owner_id = %s.id", "user")).
Where("public_key.id = ?", keyID).
First(user).
@ -859,29 +859,29 @@ func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
return user, nil
}
func (db *users) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) {
func (s *usersStore) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) {
emails := make([]string, 0, len(usernames))
return emails, db.WithContext(ctx).
return emails, s.WithContext(ctx).
Model(&User{}).
Select("email").
Where("lower_name IN (?) AND is_active = ?", usernames, true).
Find(&emails).Error
}
func (db *users) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool {
func (s *usersStore) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool {
if username == "" {
return false
}
return db.WithContext(ctx).
return s.WithContext(ctx).
Select("id").
Where("lower_name = ? AND id != ?", strings.ToLower(username), excludeUserId).
First(&User{}).
Error != gorm.ErrRecordNotFound
}
func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error) {
func (s *usersStore) List(ctx context.Context, page, pageSize int) ([]*User, error) {
users := make([]*User, 0, pageSize)
return users, db.WithContext(ctx).
return users, s.WithContext(ctx).
Where("type = ?", UserTypeIndividual).
Limit(pageSize).Offset((page - 1) * pageSize).
Order("id ASC").
@ -889,7 +889,7 @@ func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error)
Error
}
func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
func (s *usersStore) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
/*
Equivalent SQL for PostgreSQL:
@ -900,7 +900,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize
LIMIT @limit OFFSET @offset
*/
users := make([]*User, 0, pageSize)
return users, db.WithContext(ctx).
return users, s.WithContext(ctx).
Joins(dbutil.Quote("LEFT JOIN follow ON follow.user_id = %s.id", "user")).
Where("follow.follow_id = ?", userID).
Limit(pageSize).Offset((page - 1) * pageSize).
@ -909,7 +909,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize
Error
}
func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
func (s *usersStore) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
/*
Equivalent SQL for PostgreSQL:
@ -920,7 +920,7 @@ func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSiz
LIMIT @limit OFFSET @offset
*/
users := make([]*User, 0, pageSize)
return users, db.WithContext(ctx).
return users, s.WithContext(ctx).
Joins(dbutil.Quote("LEFT JOIN follow ON follow.follow_id = %s.id", "user")).
Where("follow.user_id = ?", userID).
Limit(pageSize).Offset((page - 1) * pageSize).
@ -948,8 +948,8 @@ func searchUserByName(ctx context.Context, db *gorm.DB, userType UserType, keywo
return users, count, tx.Order(orderBy).Limit(pageSize).Offset((page - 1) * pageSize).Find(&users).Error
}
func (db *users) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) {
return searchUserByName(ctx, db.DB, UserTypeIndividual, keyword, page, pageSize, orderBy)
func (s *usersStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) {
return searchUserByName(ctx, s.DB, UserTypeIndividual, keyword, page, pageSize, orderBy)
}
type UpdateUserOptions struct {
@ -979,9 +979,9 @@ type UpdateUserOptions struct {
AvatarEmail *string
}
func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error {
func (s *usersStore) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error {
updates := map[string]any{
"updated_unix": db.NowFunc().Unix(),
"updated_unix": s.NowFunc().Unix(),
}
if opts.LoginSource != nil {
@ -1012,7 +1012,7 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption
updates["full_name"] = strutil.Truncate(*opts.FullName, 255)
}
if opts.Email != nil {
_, err := db.GetByEmail(ctx, *opts.Email)
_, err := s.GetByEmail(ctx, *opts.Email)
if err == nil {
return ErrEmailAlreadyUsed{args: errutil.Args{"email": *opts.Email}}
} else if !IsErrUserNotExist(err) {
@ -1063,28 +1063,28 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption
updates["avatar_email"] = strutil.Truncate(*opts.AvatarEmail, 255)
}
return db.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error
return s.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error
}
func (db *users) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error {
func (s *usersStore) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error {
err := userutil.SaveAvatar(userID, avatar)
if err != nil {
return errors.Wrap(err, "save avatar")
}
return db.WithContext(ctx).
return s.WithContext(ctx).
Model(&User{}).
Where("id = ?", userID).
Updates(map[string]any{
"use_custom_avatar": true,
"updated_unix": db.NowFunc().Unix(),
"updated_unix": s.NowFunc().Unix(),
}).
Error
}
func (db *users) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error {
func (s *usersStore) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error {
email = strings.ToLower(strings.TrimSpace(email))
_, err := db.GetByEmail(ctx, email)
_, err := s.GetByEmail(ctx, email)
if err == nil {
return ErrEmailAlreadyUsed{
args: errutil.Args{
@ -1095,7 +1095,7 @@ func (db *users) AddEmail(ctx context.Context, userID int64, email string, isAct
return errors.Wrap(err, "check user by email")
}
return db.WithContext(ctx).Create(
return s.WithContext(ctx).Create(
&EmailAddress{
UserID: userID,
Email: email,
@ -1125,8 +1125,8 @@ func (ErrEmailNotExist) NotFound() bool {
return true
}
func (db *users) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) {
tx := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email)
func (s *usersStore) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) {
tx := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email)
if needsActivated {
tx = tx.Where("is_activated = ?", true)
}
@ -1146,14 +1146,14 @@ func (db *users) GetEmail(ctx context.Context, userID int64, email string, needs
return emailAddress, nil
}
func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) {
user, err := db.GetByID(ctx, userID)
func (s *usersStore) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) {
user, err := s.GetByID(ctx, userID)
if err != nil {
return nil, errors.Wrap(err, "get user")
}
var emails []*EmailAddress
err = db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error
err = s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error
if err != nil {
return nil, errors.Wrap(err, "list emails")
}
@ -1179,9 +1179,9 @@ func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress,
return emails, nil
}
func (db *users) MarkEmailActivated(ctx context.Context, userID int64, email string) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := db.WithContext(ctx).
func (s *usersStore) MarkEmailActivated(ctx context.Context, userID int64, email string) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := s.WithContext(ctx).
Model(&EmailAddress{}).
Where("uid = ? AND email = ?", userID, email).
Update("is_activated", true).
@ -1209,9 +1209,9 @@ func (err ErrEmailNotVerified) Error() string {
return fmt.Sprintf("email has not been verified: %v", err.args)
}
func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email string) error {
func (s *usersStore) MarkEmailPrimary(ctx context.Context, userID int64, email string) error {
var emailAddress EmailAddress
err := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error
err := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return ErrEmailNotExist{args: errutil.Args{"email": email}}
@ -1223,12 +1223,12 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin
return ErrEmailNotVerified{args: errutil.Args{"email": email}}
}
user, err := db.GetByID(ctx, userID)
user, err := s.GetByID(ctx, userID)
if err != nil {
return errors.Wrap(err, "get user")
}
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Make sure the former primary email doesn't disappear.
err = tx.FirstOrCreate(
&EmailAddress{
@ -1255,8 +1255,8 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin
})
}
func (db *users) DeleteEmail(ctx context.Context, userID int64, email string) error {
return db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error
func (s *usersStore) DeleteEmail(ctx context.Context, userID int64, email string) error {
return s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error
}
// UserType indicates the type of the user account.

View File

@ -84,13 +84,13 @@ func TestUsers(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := &users{
DB: newTestDB(t, "users"),
db := &usersStore{
DB: newTestDB(t, "usersStore"),
}
for _, tc := range []struct {
name string
test func(t *testing.T, ctx context.Context, db *users)
test func(t *testing.T, ctx context.Context, db *usersStore)
}{
{"Authenticate", usersAuthenticate},
{"ChangeUsername", usersChangeUsername},
@ -134,7 +134,7 @@ func TestUsers(t *testing.T) {
}
}
func usersAuthenticate(t *testing.T, ctx context.Context, db *users) {
func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
password := "pa$$word"
alice, err := db.Create(ctx, "alice", "alice@example.com",
CreateUserOptions{
@ -229,7 +229,7 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *users) {
})
}
func usersChangeUsername(t *testing.T, ctx context.Context, db *users) {
func usersChangeUsername(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(
ctx,
"alice",
@ -359,7 +359,7 @@ func usersChangeUsername(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, strings.ToUpper(newUsername), alice.Name)
}
func usersCount(t *testing.T, ctx context.Context, db *users) {
func usersCount(t *testing.T, ctx context.Context, db *usersStore) {
// Has no user initially
got := db.Count(ctx)
assert.Equal(t, int64(0), got)
@ -382,7 +382,7 @@ func usersCount(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, int64(1), got)
}
func usersCreate(t *testing.T, ctx context.Context, db *users) {
func usersCreate(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(
ctx,
"alice",
@ -430,7 +430,7 @@ func usersCreate(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
}
func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) {
func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -464,7 +464,7 @@ func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) {
assert.False(t, alice.UseCustomAvatar)
}
func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
func usersDeleteByID(t *testing.T, ctx context.Context, db *usersStore) {
reposStore := NewReposStore(db.DB)
t.Run("user still has repository ownership", func(t *testing.T) {
@ -674,7 +674,7 @@ func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err)
}
func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) {
func usersDeleteInactivated(t *testing.T, ctx context.Context, db *usersStore) {
// User with repository ownership should be skipped
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -720,7 +720,7 @@ func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) {
require.Len(t, users, 3)
}
func usersGetByEmail(t *testing.T, ctx context.Context, db *users) {
func usersGetByEmail(t *testing.T, ctx context.Context, db *usersStore) {
t.Run("empty email", func(t *testing.T) {
_, err := db.GetByEmail(ctx, "")
wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}}
@ -781,7 +781,7 @@ func usersGetByEmail(t *testing.T, ctx context.Context, db *users) {
})
}
func usersGetByID(t *testing.T, ctx context.Context, db *users) {
func usersGetByID(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -794,7 +794,7 @@ func usersGetByID(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err)
}
func usersGetByUsername(t *testing.T, ctx context.Context, db *users) {
func usersGetByUsername(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -807,7 +807,7 @@ func usersGetByUsername(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err)
}
func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) {
func usersGetByKeyID(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -832,7 +832,7 @@ func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err)
}
func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *users) {
func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true})
@ -846,7 +846,7 @@ func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *us
assert.Equal(t, want, got)
}
func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) {
func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -896,7 +896,7 @@ func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) {
}
}
func usersList(t *testing.T, ctx context.Context, db *users) {
func usersList(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{})
@ -929,7 +929,7 @@ func usersList(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, bob.ID, got[1].ID)
}
func usersListFollowers(t *testing.T, ctx context.Context, db *users) {
func usersListFollowers(t *testing.T, ctx context.Context, db *usersStore) {
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -960,7 +960,7 @@ func usersListFollowers(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, alice.ID, got[0].ID)
}
func usersListFollowings(t *testing.T, ctx context.Context, db *users) {
func usersListFollowings(t *testing.T, ctx context.Context, db *usersStore) {
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -991,7 +991,7 @@ func usersListFollowings(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, alice.ID, got[0].ID)
}
func usersSearchByName(t *testing.T, ctx context.Context, db *users) {
func usersSearchByName(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"})
require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"})
@ -1029,7 +1029,7 @@ func usersSearchByName(t *testing.T, ctx context.Context, db *users) {
})
}
func usersUpdate(t *testing.T, ctx context.Context, db *users) {
func usersUpdate(t *testing.T, ctx context.Context, db *usersStore) {
const oldPassword = "Password"
alice, err := db.Create(
ctx,
@ -1142,7 +1142,7 @@ func usersUpdate(t *testing.T, ctx context.Context, db *users) {
assertValues()
}
func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *users) {
func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1180,7 +1180,7 @@ func TestIsUsernameAllowed(t *testing.T) {
}
}
func usersAddEmail(t *testing.T, ctx context.Context, db *users) {
func usersAddEmail(t *testing.T, ctx context.Context, db *usersStore) {
t.Run("multiple users can add the same unverified email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1197,7 +1197,7 @@ func usersAddEmail(t *testing.T, ctx context.Context, db *users) {
})
}
func usersGetEmail(t *testing.T, ctx context.Context, db *users) {
func usersGetEmail(t *testing.T, ctx context.Context, db *usersStore) {
const testUserID = 1
const testEmail = "alice@example.com"
_, err := db.GetEmail(ctx, testUserID, testEmail, false)
@ -1229,7 +1229,7 @@ func usersGetEmail(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, testEmail, got.Email)
}
func usersListEmails(t *testing.T, ctx context.Context, db *users) {
func usersListEmails(t *testing.T, ctx context.Context, db *usersStore) {
t.Run("list emails with primary email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1265,7 +1265,7 @@ func usersListEmails(t *testing.T, ctx context.Context, db *users) {
})
}
func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) {
func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1283,7 +1283,7 @@ func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) {
assert.NotEqual(t, alice.Rands, gotAlice.Rands)
}
func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) {
func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false)
@ -1309,7 +1309,7 @@ func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) {
assert.False(t, gotEmail.IsActivated)
}
func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) {
func usersDeleteEmail(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1325,7 +1325,7 @@ func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) {
require.Equal(t, want, got)
}
func usersFollow(t *testing.T, ctx context.Context, db *users) {
func usersFollow(t *testing.T, ctx context.Context, db *usersStore) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1348,7 +1348,7 @@ func usersFollow(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, 1, bob.NumFollowers)
}
func usersIsFollowing(t *testing.T, ctx context.Context, db *users) {
func usersIsFollowing(t *testing.T, ctx context.Context, db *usersStore) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1369,7 +1369,7 @@ func usersIsFollowing(t *testing.T, ctx context.Context, db *users) {
assert.False(t, got)
}
func usersUnfollow(t *testing.T, ctx context.Context, db *users) {
func usersUnfollow(t *testing.T, ctx context.Context, db *usersStore) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)