diff --git a/internal/database/access_tokens.go b/internal/database/access_tokens.go index f097e3053..3121bf130 100644 --- a/internal/database/access_tokens.go +++ b/internal/database/access_tokens.go @@ -74,9 +74,9 @@ func (t *AccessToken) AfterFind(tx *gorm.DB) error { return nil } -var _ AccessTokensStore = (*accessTokens)(nil) +var _ AccessTokensStore = (*accessTokensStore)(nil) -type accessTokens struct { +type accessTokensStore struct { *gorm.DB } @@ -93,8 +93,8 @@ func (err ErrAccessTokenAlreadyExist) Error() string { return fmt.Sprintf("access token already exists: %v", err.args) } -func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) { - err := db.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error +func (s *accessTokensStore) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) { + err := s.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error if err == nil { return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}} } else if err != gorm.ErrRecordNotFound { @@ -110,7 +110,7 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) ( Sha1: sha256[:40], // To pass the column unique constraint, keep the length of SHA1. SHA256: sha256, } - if err = db.WithContext(ctx).Create(accessToken).Error; err != nil { + if err = s.WithContext(ctx).Create(accessToken).Error; err != nil { return nil, err } @@ -119,8 +119,8 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) ( return accessToken, nil } -func (db *accessTokens) DeleteByID(ctx context.Context, userID, id int64) error { - return db.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error +func (s *accessTokensStore) DeleteByID(ctx context.Context, userID, id int64) error { + return s.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error } var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil) @@ -144,7 +144,7 @@ func (ErrAccessTokenNotExist) NotFound() bool { return true } -func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) { +func (s *accessTokensStore) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) { // No need to waste a query for an empty SHA1. if sha1 == "" { return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}} @@ -152,7 +152,7 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke sha256 := cryptoutil.SHA256(sha1) token := new(AccessToken) - err := db.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error + err := s.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}} @@ -162,15 +162,15 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke return token, nil } -func (db *accessTokens) List(ctx context.Context, userID int64) ([]*AccessToken, error) { +func (s *accessTokensStore) List(ctx context.Context, userID int64) ([]*AccessToken, error) { var tokens []*AccessToken - return tokens, db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error + return tokens, s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error } -func (db *accessTokens) Touch(ctx context.Context, id int64) error { - return db.WithContext(ctx). +func (s *accessTokensStore) Touch(ctx context.Context, id int64) error { + return s.WithContext(ctx). Model(new(AccessToken)). Where("id = ?", id). - UpdateColumn("updated_unix", db.NowFunc().Unix()). + UpdateColumn("updated_unix", s.NowFunc().Unix()). Error } diff --git a/internal/database/access_tokens_test.go b/internal/database/access_tokens_test.go index 72b68daf9..547120071 100644 --- a/internal/database/access_tokens_test.go +++ b/internal/database/access_tokens_test.go @@ -98,13 +98,13 @@ func TestAccessTokens(t *testing.T) { t.Parallel() ctx := context.Background() - db := &accessTokens{ - DB: newTestDB(t, "accessTokens"), + db := &accessTokensStore{ + DB: newTestDB(t, "accessTokensStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *accessTokens) + test func(t *testing.T, ctx context.Context, db *accessTokensStore) }{ {"Create", accessTokensCreate}, {"DeleteByID", accessTokensDeleteByID}, @@ -125,7 +125,7 @@ func TestAccessTokens(t *testing.T) { } } -func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) { +func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokensStore) { // Create first access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) @@ -150,7 +150,7 @@ func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) { assert.Equal(t, wantErr, err) } -func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) { +func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokensStore) { // Create an access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) @@ -177,7 +177,7 @@ func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) assert.Equal(t, wantErr, err) } -func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) { +func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokensStore) { // Create an access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) @@ -196,7 +196,7 @@ func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) { assert.Equal(t, wantErr, err) } -func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) { +func accessTokensList(t *testing.T, ctx context.Context, db *accessTokensStore) { // Create two access tokens for user 1 _, err := db.Create(ctx, 1, "user1_1") require.NoError(t, err) @@ -219,7 +219,7 @@ func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) { assert.Equal(t, "user1_2", tokens[1].Name) } -func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokens) { +func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokensStore) { // Create an access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) diff --git a/internal/database/actions.go b/internal/database/actions.go index 0496b7d4f..a678fdd23 100644 --- a/internal/database/actions.go +++ b/internal/database/actions.go @@ -70,19 +70,19 @@ type ActionsStore interface { var Actions ActionsStore -var _ ActionsStore = (*actions)(nil) +var _ ActionsStore = (*actionsStore)(nil) -type actions struct { +type actionsStore struct { *gorm.DB } // NewActionsStore returns a persistent interface for actions with given // database connection. func NewActionsStore(db *gorm.DB) ActionsStore { - return &actions{DB: db} + return &actionsStore{DB: db} } -func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB { +func (s *actionsStore) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB { /* Equivalent SQL for PostgreSQL: @@ -102,18 +102,18 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after ORDER BY id DESC LIMIT @limit */ - return db.WithContext(ctx). + return s.WithContext(ctx). Where("user_id = ?", orgID). - Where(db. + Where(s. // Not apply when afterID is not given Where("?", afterID <= 0). Or("id < ?", afterID), ). - Where("repo_id IN (?)", db. + Where("repo_id IN (?)", s. Select("repository.id"). Table("repository"). Joins("JOIN team_repo ON repository.id = team_repo.repo_id"). - Where("team_repo.team_id IN (?)", db. + Where("team_repo.team_id IN (?)", s. Select("team_id"). Table("team_user"). Where("team_user.org_id = ? AND uid = ?", orgID, actorID), @@ -124,12 +124,12 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after Order("id DESC") } -func (db *actions) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) { +func (s *actionsStore) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) { actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum) - return actions, db.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error + return actions, s.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error } -func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB { +func (s *actionsStore) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB { /* Equivalent SQL for PostgreSQL: @@ -141,14 +141,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6 ORDER BY id DESC LIMIT @limit */ - return db.WithContext(ctx). + return s.WithContext(ctx). Where("user_id = ?", userID). - Where(db. + Where(s. // Not apply when afterID is not given Where("?", afterID <= 0). Or("id < ?", afterID), ). - Where(db. + Where(s. // Not apply when in not profile page or the user is viewing own profile Where("?", !isProfile || actorID == userID). Or("is_private = ? AND act_user_id = ?", false, userID), @@ -157,14 +157,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6 Order("id DESC") } -func (db *actions) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) { +func (s *actionsStore) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) { actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum) - return actions, db.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error + return actions, s.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error } // notifyWatchers creates rows in action table for watchers who are able to see the action. -func (db *actions) notifyWatchers(ctx context.Context, act *Action) error { - watches, err := NewReposStore(db.DB).ListWatches(ctx, act.RepoID) +func (s *actionsStore) notifyWatchers(ctx context.Context, act *Action) error { + watches, err := NewReposStore(s.DB).ListWatches(ctx, act.RepoID) if err != nil { return errors.Wrap(err, "list watches") } @@ -187,16 +187,16 @@ func (db *actions) notifyWatchers(ctx context.Context, act *Action) error { actions = append(actions, clone(watch.UserID)) } - return db.Create(actions).Error + return s.Create(actions).Error } -func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error { +func (s *actionsStore) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error { opType := ActionCreateRepo if repo.IsFork { opType = ActionForkRepo } - return db.notifyWatchers(ctx, + return s.notifyWatchers(ctx, &Action{ ActUserID: doer.ID, ActUserName: doer.Name, @@ -209,8 +209,8 @@ func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Reposit ) } -func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error { - return db.notifyWatchers(ctx, +func (s *actionsStore) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error { + return s.notifyWatchers(ctx, &Action{ ActUserID: doer.ID, ActUserName: doer.Name, @@ -224,8 +224,8 @@ func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoNam ) } -func (db *actions) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error { - return db.notifyWatchers(ctx, +func (s *actionsStore) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error { + return s.notifyWatchers(ctx, &Action{ ActUserID: owner.ID, ActUserName: owner.Name, @@ -249,13 +249,13 @@ type MirrorSyncPushOptions struct { Commits *PushCommits } -func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error { +func (s *actionsStore) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error { if conf.UI.FeedMaxCommitNum > 0 && len(opts.Commits.Commits) > conf.UI.FeedMaxCommitNum { opts.Commits.Commits = opts.Commits.Commits[:conf.UI.FeedMaxCommitNum] } apiCommits, err := opts.Commits.APIFormat(ctx, - NewUsersStore(db.DB), + NewUsersStore(s.DB), repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name), repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name), ) @@ -288,19 +288,19 @@ func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOption return errors.Wrap(err, "marshal JSON") } - return db.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data) + return s.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data) } -func (db *actions) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error { - return db.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil) +func (s *actionsStore) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error { + return s.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil) } -func (db *actions) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error { - return db.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil) +func (s *actionsStore) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error { + return s.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil) } -func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error { - return db.notifyWatchers(ctx, +func (s *actionsStore) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error { + return s.notifyWatchers(ctx, &Action{ ActUserID: doer.ID, ActUserName: doer.Name, @@ -314,8 +314,8 @@ func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo ) } -func (db *actions) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error { - return db.notifyWatchers(ctx, +func (s *actionsStore) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error { + return s.notifyWatchers(ctx, &Action{ ActUserID: doer.ID, ActUserName: doer.Name, @@ -487,13 +487,13 @@ type CommitRepoOptions struct { Commits *PushCommits } -func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error { - err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID) +func (s *actionsStore) CommitRepo(ctx context.Context, opts CommitRepoOptions) error { + err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID) if err != nil { return errors.Wrap(err, "touch repository") } - pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName) + pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName) if err != nil { return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName) } @@ -536,7 +536,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error } action.OpType = ActionDeleteBranch - err = db.notifyWatchers(ctx, action) + err = s.notifyWatchers(ctx, action) if err != nil { return errors.Wrap(err, "notify watchers") } @@ -580,7 +580,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error } action.OpType = ActionCreateBranch - err = db.notifyWatchers(ctx, action) + err = s.notifyWatchers(ctx, action) if err != nil { return errors.Wrap(err, "notify watchers") } @@ -589,7 +589,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error } commits, err := opts.Commits.APIFormat(ctx, - NewUsersStore(db.DB), + NewUsersStore(s.DB), repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name), repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name), ) @@ -616,7 +616,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error } action.OpType = ActionCommitRepo - err = db.notifyWatchers(ctx, action) + err = s.notifyWatchers(ctx, action) if err != nil { return errors.Wrap(err, "notify watchers") } @@ -631,13 +631,13 @@ type PushTagOptions struct { NewCommitID string } -func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error { - err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID) +func (s *actionsStore) PushTag(ctx context.Context, opts PushTagOptions) error { + err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID) if err != nil { return errors.Wrap(err, "touch repository") } - pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName) + pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName) if err != nil { return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName) } @@ -672,7 +672,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error { } action.OpType = ActionDeleteTag - err = db.notifyWatchers(ctx, action) + err = s.notifyWatchers(ctx, action) if err != nil { return errors.Wrap(err, "notify watchers") } @@ -696,7 +696,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error { } action.OpType = ActionPushTag - err = db.notifyWatchers(ctx, action) + err = s.notifyWatchers(ctx, action) if err != nil { return errors.Wrap(err, "notify watchers") } diff --git a/internal/database/actions_test.go b/internal/database/actions_test.go index 81b87d637..5fd3f7580 100644 --- a/internal/database/actions_test.go +++ b/internal/database/actions_test.go @@ -99,13 +99,13 @@ func TestActions(t *testing.T) { ctx := context.Background() t.Parallel() - db := &actions{ - DB: newTestDB(t, "actions"), + db := &actionsStore{ + DB: newTestDB(t, "actionsStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *actions) + test func(t *testing.T, ctx context.Context, db *actionsStore) }{ {"CommitRepo", actionsCommitRepo}, {"ListByOrganization", actionsListByOrganization}, @@ -132,7 +132,7 @@ func TestActions(t *testing.T) { } } -func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) { +func actionsCommitRepo(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -324,7 +324,7 @@ func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) { }) } -func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) { +func actionsListByOrganization(t *testing.T, ctx context.Context, db *actionsStore) { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { t.Skip("Skipping testing with not using PostgreSQL") return @@ -363,14 +363,14 @@ func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB { - return NewActionsStore(tx).(*actions).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action)) + return NewActionsStore(tx).(*actionsStore).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action)) }) assert.Equal(t, test.want, got) }) } } -func actionsListByUser(t *testing.T, ctx context.Context, db *actions) { +func actionsListByUser(t *testing.T, ctx context.Context, db *actionsStore) { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { t.Skip("Skipping testing with not using PostgreSQL") return @@ -428,14 +428,14 @@ func actionsListByUser(t *testing.T, ctx context.Context, db *actions) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB { - return NewActionsStore(tx).(*actions).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action)) + return NewActionsStore(tx).(*actionsStore).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action)) }) assert.Equal(t, test.want, got) }) } } -func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) { +func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -480,7 +480,7 @@ func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) { assert.Equal(t, want, got) } -func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) { +func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -521,7 +521,7 @@ func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) { assert.Equal(t, want, got) } -func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) { +func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -562,7 +562,7 @@ func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) { assert.Equal(t, want, got) } -func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) { +func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -627,7 +627,7 @@ func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) { assert.Equal(t, want, got) } -func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) { +func actionsNewRepo(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -702,7 +702,7 @@ func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) { }) } -func actionsPushTag(t *testing.T, ctx context.Context, db *actions) { +func actionsPushTag(t *testing.T, ctx context.Context, db *actionsStore) { // NOTE: We set a noop mock here to avoid data race with other tests that writes // to the mock server because this function holds a lock. conf.SetMockServer(t, conf.ServerOpts{}) @@ -798,7 +798,7 @@ func actionsPushTag(t *testing.T, ctx context.Context, db *actions) { }) } -func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) { +func actionsRenameRepo(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -835,7 +835,7 @@ func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) { assert.Equal(t, want, got) } -func actionsTransferRepo(t *testing.T, ctx context.Context, db *actions) { +func actionsTransferRepo(t *testing.T, ctx context.Context, db *actionsStore) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) diff --git a/internal/database/database.go b/internal/database/database.go index a296b3346..f0b800b41 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -123,15 +123,15 @@ func Init(w logger.Writer) (*gorm.DB, error) { } // Initialize stores, sorted in alphabetical order. - AccessTokens = &accessTokens{DB: db} + AccessTokens = &accessTokensStore{DB: db} Actions = NewActionsStore(db) - LoginSources = &loginSources{DB: db, files: sourceFiles} - LFS = &lfs{DB: db} + LoginSources = &loginSourcesStore{DB: db, files: sourceFiles} + LFS = &lfsStore{DB: db} Notices = NewNoticesStore(db) Orgs = NewOrgsStore(db) Perms = NewPermsStore(db) Repos = NewReposStore(db) - TwoFactors = &twoFactors{DB: db} + TwoFactors = &twoFactorsStore{DB: db} Users = NewUsersStore(db) return db, nil diff --git a/internal/database/lfs.go b/internal/database/lfs.go index 67ec63aa2..cf5616fd6 100644 --- a/internal/database/lfs.go +++ b/internal/database/lfs.go @@ -38,20 +38,20 @@ type LFSObject struct { CreatedAt time.Time `gorm:"not null"` } -var _ LFSStore = (*lfs)(nil) +var _ LFSStore = (*lfsStore)(nil) -type lfs struct { +type lfsStore struct { *gorm.DB } -func (db *lfs) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error { +func (s *lfsStore) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error { object := &LFSObject{ RepoID: repoID, OID: oid, Size: size, Storage: storage, } - return db.WithContext(ctx).Create(object).Error + return s.WithContext(ctx).Create(object).Error } type ErrLFSObjectNotExist struct { @@ -71,9 +71,9 @@ func (ErrLFSObjectNotExist) NotFound() bool { return true } -func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) { +func (s *lfsStore) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) { object := new(LFSObject) - err := db.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error + err := s.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}} @@ -83,13 +83,13 @@ func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID return object, err } -func (db *lfs) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) { +func (s *lfsStore) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) { if len(oids) == 0 { return []*LFSObject{}, nil } objects := make([]*LFSObject, 0, len(oids)) - err := db.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error + err := s.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error if err != nil && err != gorm.ErrRecordNotFound { return nil, err } diff --git a/internal/database/lfs_test.go b/internal/database/lfs_test.go index 57966d116..cfff2fe31 100644 --- a/internal/database/lfs_test.go +++ b/internal/database/lfs_test.go @@ -23,13 +23,13 @@ func TestLFS(t *testing.T) { t.Parallel() ctx := context.Background() - db := &lfs{ - DB: newTestDB(t, "lfs"), + db := &lfsStore{ + DB: newTestDB(t, "lfsStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *lfs) + test func(t *testing.T, ctx context.Context, db *lfsStore) }{ {"CreateObject", lfsCreateObject}, {"GetObjectByOID", lfsGetObjectByOID}, @@ -48,7 +48,7 @@ func TestLFS(t *testing.T) { } } -func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) { +func lfsCreateObject(t *testing.T, ctx context.Context, db *lfsStore) { // Create first LFS object repoID := int64(1) oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") @@ -65,7 +65,7 @@ func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) { assert.Error(t, err) } -func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) { +func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfsStore) { // Create a LFS object repoID := int64(1) oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") @@ -82,7 +82,7 @@ func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) { assert.Equal(t, expErr, err) } -func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfs) { +func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfsStore) { // Create two LFS objects repoID := int64(1) oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") diff --git a/internal/database/login_sources.go b/internal/database/login_sources.go index 6f5846295..d35234813 100644 --- a/internal/database/login_sources.go +++ b/internal/database/login_sources.go @@ -180,9 +180,9 @@ func (s *LoginSource) GitHub() *github.Config { return s.Provider.Config().(*github.Config) } -var _ LoginSourcesStore = (*loginSources)(nil) +var _ LoginSourcesStore = (*loginSourcesStore)(nil) -type loginSources struct { +type loginSourcesStore struct { *gorm.DB files loginSourceFilesStore } @@ -208,8 +208,8 @@ func (err ErrLoginSourceAlreadyExist) Error() string { return fmt.Sprintf("login source already exists: %v", err.args) } -func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) { - err := db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error +func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) { + err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error if err == nil { return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}} } else if err != gorm.ErrRecordNotFound { @@ -226,13 +226,13 @@ func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOption if err != nil { return nil, err } - return source, db.WithContext(ctx).Create(source).Error + return source, s.WithContext(ctx).Create(source).Error } -func (db *loginSources) Count(ctx context.Context) int64 { +func (s *loginSourcesStore) Count(ctx context.Context) int64 { var count int64 - db.WithContext(ctx).Model(new(LoginSource)).Count(&count) - return count + int64(db.files.Len()) + s.WithContext(ctx).Model(new(LoginSource)).Count(&count) + return count + int64(s.files.Len()) } type ErrLoginSourceInUse struct { @@ -248,24 +248,24 @@ func (err ErrLoginSourceInUse) Error() string { return fmt.Sprintf("login source is still used by some users: %v", err.args) } -func (db *loginSources) DeleteByID(ctx context.Context, id int64) error { +func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error { var count int64 - err := db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error + err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error if err != nil { return err } else if count > 0 { return ErrLoginSourceInUse{args: errutil.Args{"id": id}} } - return db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error + return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error } -func (db *loginSources) GetByID(ctx context.Context, id int64) (*LoginSource, error) { +func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) { source := new(LoginSource) - err := db.WithContext(ctx).Where("id = ?", id).First(source).Error + err := s.WithContext(ctx).Where("id = ?", id).First(source).Error if err != nil { if err == gorm.ErrRecordNotFound { - return db.files.GetByID(id) + return s.files.GetByID(id) } return nil, err } @@ -277,9 +277,9 @@ type ListLoginSourceOptions struct { OnlyActivated bool } -func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) { +func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) { var sources []*LoginSource - query := db.WithContext(ctx).Order("id ASC") + query := s.WithContext(ctx).Order("id ASC") if opts.OnlyActivated { query = query.Where("is_actived = ?", true) } @@ -288,11 +288,11 @@ func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) ( return nil, err } - return append(sources, db.files.List(opts)...), nil + return append(sources, s.files.List(opts)...), nil } -func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) error { - err := db.WithContext(ctx). +func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error { + err := s.WithContext(ctx). Model(new(LoginSource)). Where("id != ?", dflt.ID). Updates(map[string]any{"is_default": false}). @@ -301,7 +301,7 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) return err } - for _, source := range db.files.List(ListLoginSourceOptions{}) { + for _, source := range s.files.List(ListLoginSourceOptions{}) { if source.File != nil && source.ID != dflt.ID { source.File.SetGeneral("is_default", "false") if err = source.File.Save(); err != nil { @@ -310,13 +310,13 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) } } - db.files.Update(dflt) + s.files.Update(dflt) return nil } -func (db *loginSources) Save(ctx context.Context, source *LoginSource) error { +func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error { if source.File == nil { - return db.WithContext(ctx).Save(source).Error + return s.WithContext(ctx).Save(source).Error } source.File.SetGeneral("name", source.Name) diff --git a/internal/database/login_sources_test.go b/internal/database/login_sources_test.go index f4299c975..167ed90db 100644 --- a/internal/database/login_sources_test.go +++ b/internal/database/login_sources_test.go @@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) { t.Parallel() ctx := context.Background() - db := &loginSources{ - DB: newTestDB(t, "loginSources"), + db := &loginSourcesStore{ + DB: newTestDB(t, "loginSourcesStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *loginSources) + test func(t *testing.T, ctx context.Context, db *loginSourcesStore) }{ {"Create", loginSourcesCreate}, {"Count", loginSourcesCount}, @@ -192,7 +192,7 @@ func TestLoginSources(t *testing.T) { } } -func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) { +func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) { // Create first login source with name "GitHub" source, err := db.Create(ctx, CreateLoginSourceOptions{ @@ -219,7 +219,7 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) { assert.Equal(t, wantErr, err) } -func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) { +func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) { // Create two login sources, one in database and one as source file. _, err := db.Create(ctx, CreateLoginSourceOptions{ @@ -241,7 +241,7 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) { assert.Equal(t, int64(3), db.Count(ctx)) } -func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) { +func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) { t.Run("delete but in used", func(t *testing.T) { source, err := db.Create(ctx, CreateLoginSourceOptions{ @@ -257,7 +257,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) require.NoError(t, err) // Create a user that uses this login source - _, err = (&users{DB: db.DB}).Create(ctx, "alice", "", + _, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "", CreateUserOptions{ LoginSource: source.ID, }, @@ -308,7 +308,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) assert.Equal(t, wantErr, err) } -func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) { +func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) { mock := NewMockLoginSourceFilesStore() mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) { if id != 101 { @@ -344,7 +344,7 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) { require.NoError(t, err) } -func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) { +func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) { mock := NewMockLoginSourceFilesStore() mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { if opts.OnlyActivated { @@ -393,7 +393,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) { assert.Equal(t, 2, len(sources), "number of sources") } -func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) { +func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) { mock := NewMockLoginSourceFilesStore() mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { mockFile := NewMockLoginSourceFileStore() @@ -448,7 +448,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou assert.False(t, source2.IsDefault) } -func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) { +func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) { t.Run("save to database", func(t *testing.T) { // Create a login source with name "GitHub" source, err := db.Create(ctx, diff --git a/internal/database/mocks.go b/internal/database/mocks.go index 8d1ffdbc5..40c2cfc62 100644 --- a/internal/database/mocks.go +++ b/internal/database/mocks.go @@ -32,7 +32,7 @@ func setMockLoginSourcesStore(t *testing.T, mock LoginSourcesStore) { }) } -func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) { +func setMockLoginSourceFilesStore(t *testing.T, db *loginSourcesStore, mock loginSourceFilesStore) { before := db.files db.files = mock t.Cleanup(func() { diff --git a/internal/database/notices.go b/internal/database/notices.go index 37d6777fa..ebcb5dd91 100644 --- a/internal/database/notices.go +++ b/internal/database/notices.go @@ -32,20 +32,20 @@ type NoticesStore interface { var Notices NoticesStore -var _ NoticesStore = (*notices)(nil) +var _ NoticesStore = (*noticesStore)(nil) -type notices struct { +type noticesStore struct { *gorm.DB } // NewNoticesStore returns a persistent interface for system notices with given // database connection. func NewNoticesStore(db *gorm.DB) NoticesStore { - return ¬ices{DB: db} + return ¬icesStore{DB: db} } -func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) error { - return db.WithContext(ctx).Create( +func (s *noticesStore) Create(ctx context.Context, typ NoticeType, desc string) error { + return s.WithContext(ctx).Create( &Notice{ Type: typ, Description: desc, @@ -53,26 +53,26 @@ func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) erro ).Error } -func (db *notices) DeleteByIDs(ctx context.Context, ids ...int64) error { - return db.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error +func (s *noticesStore) DeleteByIDs(ctx context.Context, ids ...int64) error { + return s.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error } -func (db *notices) DeleteAll(ctx context.Context) error { - return db.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error +func (s *noticesStore) DeleteAll(ctx context.Context) error { + return s.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error } -func (db *notices) List(ctx context.Context, page, pageSize int) ([]*Notice, error) { +func (s *noticesStore) List(ctx context.Context, page, pageSize int) ([]*Notice, error) { notices := make([]*Notice, 0, pageSize) - return notices, db.WithContext(ctx). + return notices, s.WithContext(ctx). Limit(pageSize).Offset((page - 1) * pageSize). Order("id DESC"). Find(¬ices). Error } -func (db *notices) Count(ctx context.Context) int64 { +func (s *noticesStore) Count(ctx context.Context) int64 { var count int64 - db.WithContext(ctx).Model(&Notice{}).Count(&count) + s.WithContext(ctx).Model(&Notice{}).Count(&count) return count } diff --git a/internal/database/notices_test.go b/internal/database/notices_test.go index 3c7efcc90..56d631f3b 100644 --- a/internal/database/notices_test.go +++ b/internal/database/notices_test.go @@ -65,13 +65,13 @@ func TestNotices(t *testing.T) { t.Parallel() ctx := context.Background() - db := ¬ices{ - DB: newTestDB(t, "notices"), + db := ¬icesStore{ + DB: newTestDB(t, "noticesStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *notices) + test func(t *testing.T, ctx context.Context, db *noticesStore) }{ {"Create", noticesCreate}, {"DeleteByIDs", noticesDeleteByIDs}, @@ -92,7 +92,7 @@ func TestNotices(t *testing.T) { } } -func noticesCreate(t *testing.T, ctx context.Context, db *notices) { +func noticesCreate(t *testing.T, ctx context.Context, db *noticesStore) { err := db.Create(ctx, NoticeTypeRepository, "test") require.NoError(t, err) @@ -100,7 +100,7 @@ func noticesCreate(t *testing.T, ctx context.Context, db *notices) { assert.Equal(t, int64(1), count) } -func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) { +func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *noticesStore) { err := db.Create(ctx, NoticeTypeRepository, "test") require.NoError(t, err) @@ -120,7 +120,7 @@ func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) { assert.Equal(t, int64(0), count) } -func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) { +func noticesDeleteAll(t *testing.T, ctx context.Context, db *noticesStore) { err := db.Create(ctx, NoticeTypeRepository, "test") require.NoError(t, err) @@ -131,7 +131,7 @@ func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) { assert.Equal(t, int64(0), count) } -func noticesList(t *testing.T, ctx context.Context, db *notices) { +func noticesList(t *testing.T, ctx context.Context, db *noticesStore) { err := db.Create(ctx, NoticeTypeRepository, "test 1") require.NoError(t, err) err = db.Create(ctx, NoticeTypeRepository, "test 2") @@ -151,7 +151,7 @@ func noticesList(t *testing.T, ctx context.Context, db *notices) { require.Len(t, got, 2) } -func noticesCount(t *testing.T, ctx context.Context, db *notices) { +func noticesCount(t *testing.T, ctx context.Context, db *noticesStore) { count := db.Count(ctx) assert.Equal(t, int64(0), count) diff --git a/internal/database/orgs.go b/internal/database/orgs.go index 233c39d54..e63ac94c8 100644 --- a/internal/database/orgs.go +++ b/internal/database/orgs.go @@ -30,16 +30,16 @@ type OrgsStore interface { var Orgs OrgsStore -var _ OrgsStore = (*orgs)(nil) +var _ OrgsStore = (*orgsStore)(nil) -type orgs struct { +type orgsStore struct { *gorm.DB } // NewOrgsStore returns a persistent interface for orgs with given database // connection. func NewOrgsStore(db *gorm.DB) OrgsStore { - return &orgs{DB: db} + return &orgsStore{DB: db} } type ListOrgsOptions struct { @@ -49,7 +49,7 @@ type ListOrgsOptions struct { IncludePrivateMembers bool } -func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) { +func (s *orgsStore) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) { if opts.MemberID <= 0 { return nil, errors.New("MemberID must be greater than 0") } @@ -64,7 +64,7 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization [AND org_user.is_public = @includePrivateMembers] ORDER BY org.id ASC */ - tx := db.WithContext(ctx). + tx := s.WithContext(ctx). Joins(dbutil.Quote("JOIN org_user ON org_user.org_id = %s.id", "user")). Where("org_user.uid = ?", opts.MemberID). Order(dbutil.Quote("%s.id ASC", "user")) @@ -76,13 +76,13 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization return orgs, tx.Find(&orgs).Error } -func (db *orgs) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) { - return searchUserByName(ctx, db.DB, UserTypeOrganization, keyword, page, pageSize, orderBy) +func (s *orgsStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) { + return searchUserByName(ctx, s.DB, UserTypeOrganization, keyword, page, pageSize, orderBy) } -func (db *orgs) CountByUser(ctx context.Context, userID int64) (int64, error) { +func (s *orgsStore) CountByUser(ctx context.Context, userID int64) (int64, error) { var count int64 - return count, db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error + return count, s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error } type Organization = User diff --git a/internal/database/orgs_test.go b/internal/database/orgs_test.go index 188f5605c..abe868516 100644 --- a/internal/database/orgs_test.go +++ b/internal/database/orgs_test.go @@ -21,13 +21,13 @@ func TestOrgs(t *testing.T) { t.Parallel() ctx := context.Background() - db := &orgs{ - DB: newTestDB(t, "orgs"), + db := &orgsStore{ + DB: newTestDB(t, "orgsStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *orgs) + test func(t *testing.T, ctx context.Context, db *orgsStore) }{ {"List", orgsList}, {"SearchByName", orgsSearchByName}, @@ -46,7 +46,7 @@ func TestOrgs(t *testing.T) { } } -func orgsList(t *testing.T, ctx context.Context, db *orgs) { +func orgsList(t *testing.T, ctx context.Context, db *orgsStore) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -116,7 +116,7 @@ func orgsList(t *testing.T, ctx context.Context, db *orgs) { } } -func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) { +func orgsSearchByName(t *testing.T, ctx context.Context, db *orgsStore) { // TODO: Use Orgs.Create to replace SQL hack when the method is available. usersStore := NewUsersStore(db.DB) org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"}) @@ -161,7 +161,7 @@ func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) { }) } -func orgsCountByUser(t *testing.T, ctx context.Context, db *orgs) { +func orgsCountByUser(t *testing.T, ctx context.Context, db *orgsStore) { // TODO: Use Orgs.Join to replace SQL hack when the method is available. err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error require.NoError(t, err) diff --git a/internal/database/perms.go b/internal/database/perms.go index bf2c2692c..ec6662704 100644 --- a/internal/database/perms.go +++ b/internal/database/perms.go @@ -74,16 +74,16 @@ func ParseAccessMode(permission string) AccessMode { } } -var _ PermsStore = (*perms)(nil) +var _ PermsStore = (*permsStore)(nil) -type perms struct { +type permsStore struct { *gorm.DB } // NewPermsStore returns a persistent interface for permissions with given // database connection. func NewPermsStore(db *gorm.DB) PermsStore { - return &perms{DB: db} + return &permsStore{DB: db} } type AccessModeOptions struct { @@ -91,7 +91,7 @@ type AccessModeOptions struct { Private bool // Whether the repository is private. } -func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) { +func (s *permsStore) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) { if repoID <= 0 { return AccessModeNone } @@ -111,7 +111,7 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce } access := new(Access) - err := db.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error + err := s.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error if err != nil { if err != gorm.ErrRecordNotFound { log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repoID, err) @@ -121,11 +121,11 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce return access.Mode } -func (db *perms) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool { - return desired <= db.AccessMode(ctx, userID, repoID, opts) +func (s *permsStore) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool { + return desired <= s.AccessMode(ctx, userID, repoID, opts) } -func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error { +func (s *permsStore) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error { records := make([]*Access, 0, len(accessMap)) for userID, mode := range accessMap { records = append(records, &Access{ @@ -135,7 +135,7 @@ func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[i }) } - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := tx.Where("repo_id = ?", repoID).Delete(new(Access)).Error if err != nil { return err diff --git a/internal/database/perms_test.go b/internal/database/perms_test.go index 4a4951330..9304b71f2 100644 --- a/internal/database/perms_test.go +++ b/internal/database/perms_test.go @@ -19,13 +19,13 @@ func TestPerms(t *testing.T) { t.Parallel() ctx := context.Background() - db := &perms{ - DB: newTestDB(t, "perms"), + db := &permsStore{ + DB: newTestDB(t, "permsStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *perms) + test func(t *testing.T, ctx context.Context, db *permsStore) }{ {"AccessMode", permsAccessMode}, {"Authorize", permsAuthorize}, @@ -44,7 +44,7 @@ func TestPerms(t *testing.T) { } } -func permsAccessMode(t *testing.T, ctx context.Context, db *perms) { +func permsAccessMode(t *testing.T, ctx context.Context, db *permsStore) { // Set up permissions err := db.SetRepoPerms(ctx, 1, map[int64]AccessMode{ @@ -155,7 +155,7 @@ func permsAccessMode(t *testing.T, ctx context.Context, db *perms) { } } -func permsAuthorize(t *testing.T, ctx context.Context, db *perms) { +func permsAuthorize(t *testing.T, ctx context.Context, db *permsStore) { // Set up permissions err := db.SetRepoPerms(ctx, 1, map[int64]AccessMode{ @@ -241,7 +241,7 @@ func permsAuthorize(t *testing.T, ctx context.Context, db *perms) { } } -func permsSetRepoPerms(t *testing.T, ctx context.Context, db *perms) { +func permsSetRepoPerms(t *testing.T, ctx context.Context, db *permsStore) { for _, update := range []struct { repoID int64 accessMap map[int64]AccessMode diff --git a/internal/database/public_keys.go b/internal/database/public_keys.go index 8cdf64410..3856bd353 100644 --- a/internal/database/public_keys.go +++ b/internal/database/public_keys.go @@ -24,23 +24,23 @@ type PublicKeysStore interface { var PublicKeys PublicKeysStore -var _ PublicKeysStore = (*publicKeys)(nil) +var _ PublicKeysStore = (*publicKeysStore)(nil) -type publicKeys struct { +type publicKeysStore struct { *gorm.DB } // NewPublicKeysStore returns a persistent interface for public keys with given // database connection. func NewPublicKeysStore(db *gorm.DB) PublicKeysStore { - return &publicKeys{DB: db} + return &publicKeysStore{DB: db} } func authorizedKeysPath() string { return filepath.Join(conf.SSH.RootPath, "authorized_keys") } -func (db *publicKeys) RewriteAuthorizedKeys() error { +func (s *publicKeysStore) RewriteAuthorizedKeys() error { sshOpLocker.Lock() defer sshOpLocker.Unlock() @@ -61,7 +61,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error { // NOTE: More recently updated keys are more likely to be used more frequently, // putting them in the earlier lines could speed up the key lookup by SSHD. - rows, err := db.Model(&PublicKey{}).Order("updated_unix DESC").Rows() + rows, err := s.Model(&PublicKey{}).Order("updated_unix DESC").Rows() if err != nil { return errors.Wrap(err, "iterate public keys") } @@ -69,7 +69,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error { for rows.Next() { var key PublicKey - err = db.ScanRows(rows, &key) + err = s.ScanRows(rows, &key) if err != nil { return errors.Wrap(err, "scan rows") } diff --git a/internal/database/public_keys_test.go b/internal/database/public_keys_test.go index bae252a9d..9a361be81 100644 --- a/internal/database/public_keys_test.go +++ b/internal/database/public_keys_test.go @@ -24,13 +24,13 @@ func TestPublicKeys(t *testing.T) { t.Parallel() ctx := context.Background() - db := &publicKeys{ - DB: newTestDB(t, "publicKeys"), + db := &publicKeysStore{ + DB: newTestDB(t, "publicKeysStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *publicKeys) + test func(t *testing.T, ctx context.Context, db *publicKeysStore) }{ {"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys}, } { @@ -47,7 +47,7 @@ func TestPublicKeys(t *testing.T) { } } -func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeys) { +func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeysStore) { // TODO: Use PublicKeys.Add to replace SQL hack when the method is available. publicKey := &PublicKey{ OwnerID: 1, diff --git a/internal/database/repos.go b/internal/database/repos.go index 17c46d330..731c2edd4 100644 --- a/internal/database/repos.go +++ b/internal/database/repos.go @@ -119,16 +119,16 @@ func (r *Repository) APIFormat(owner *User, opts ...RepositoryAPIFormatOptions) } } -var _ ReposStore = (*repos)(nil) +var _ ReposStore = (*reposStore)(nil) -type repos struct { +type reposStore struct { *gorm.DB } // NewReposStore returns a persistent interface for repositories with given // database connection. func NewReposStore(db *gorm.DB) ReposStore { - return &repos{DB: db} + return &reposStore{DB: db} } type ErrRepoAlreadyExist struct { @@ -157,13 +157,13 @@ type CreateRepoOptions struct { ForkID int64 } -func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) { +func (s *reposStore) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) { err := isRepoNameAllowed(opts.Name) if err != nil { return nil, err } - _, err = db.GetByName(ctx, ownerID, opts.Name) + _, err = s.GetByName(ctx, ownerID, opts.Name) if err == nil { return nil, ErrRepoAlreadyExist{ args: errutil.Args{ @@ -189,7 +189,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio IsFork: opts.Fork, ForkID: opts.ForkID, } - return repo, db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return repo, s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err = tx.Create(repo).Error if err != nil { return errors.Wrap(err, "create") @@ -203,7 +203,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio }) } -func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) { +func (s *reposStore) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) { /* Equivalent SQL for PostgreSQL: @@ -214,7 +214,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64, LIMIT @limit */ var repos []*Repository - return repos, db.WithContext(ctx). + return repos, s.WithContext(ctx). Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID). Where("access.mode >= ?", AccessModeRead). Order(orderBy). @@ -223,7 +223,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64, Error } -func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) { +func (s *reposStore) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) { /* Equivalent SQL for PostgreSQL: @@ -238,7 +238,7 @@ func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collabor *Repository Mode AccessMode } - err := db.WithContext(ctx). + err := s.WithContext(ctx). Select("repository.*", "access.mode"). Table("repository"). Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID). @@ -275,9 +275,9 @@ func (ErrRepoNotExist) NotFound() bool { return true } -func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) { +func (s *reposStore) GetByID(ctx context.Context, id int64) (*Repository, error) { repo := new(Repository) - err := db.WithContext(ctx).Where("id = ?", id).First(repo).Error + err := s.WithContext(ctx).Where("id = ?", id).First(repo).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrRepoNotExist{errutil.Args{"repoID": id}} @@ -287,9 +287,9 @@ func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) { return repo, nil } -func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) { +func (s *reposStore) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) { repo := new(Repository) - err := db.WithContext(ctx). + err := s.WithContext(ctx). Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)). First(repo). Error @@ -307,7 +307,7 @@ func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Re return repo, nil } -func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error { +func (s *reposStore) recountStars(tx *gorm.DB, userID, repoID int64) error { /* Equivalent SQL for PostgreSQL: @@ -350,40 +350,40 @@ func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error { return nil } -func (db *repos) Star(ctx context.Context, userID, repoID int64) error { - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - s := &Star{ +func (s *reposStore) Star(ctx context.Context, userID, repoID int64) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + star := &Star{ UserID: userID, RepoID: repoID, } - result := tx.FirstOrCreate(s, s) + result := tx.FirstOrCreate(star, star) if result.Error != nil { return errors.Wrap(result.Error, "upsert") } else if result.RowsAffected <= 0 { return nil // Relation already exists } - return db.recountStars(tx, userID, repoID) + return s.recountStars(tx, userID, repoID) }) } -func (db *repos) Touch(ctx context.Context, id int64) error { - return db.WithContext(ctx). +func (s *reposStore) Touch(ctx context.Context, id int64) error { + return s.WithContext(ctx). Model(new(Repository)). Where("id = ?", id). Updates(map[string]any{ "is_bare": false, - "updated_unix": db.NowFunc().Unix(), + "updated_unix": s.NowFunc().Unix(), }). Error } -func (db *repos) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) { +func (s *reposStore) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) { var watches []*Watch - return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error + return watches, s.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error } -func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error { +func (s *reposStore) recountWatches(tx *gorm.DB, repoID int64) error { /* Equivalent SQL for PostgreSQL: @@ -402,8 +402,8 @@ func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error { Error } -func (db *repos) Watch(ctx context.Context, userID, repoID int64) error { - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { +func (s *reposStore) Watch(ctx context.Context, userID, repoID int64) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { w := &Watch{ UserID: userID, RepoID: repoID, @@ -415,12 +415,12 @@ func (db *repos) Watch(ctx context.Context, userID, repoID int64) error { return nil // Relation already exists } - return db.recountWatches(tx, repoID) + return s.recountWatches(tx, repoID) }) } -func (db *repos) HasForkedBy(ctx context.Context, repoID, userID int64) bool { +func (s *reposStore) HasForkedBy(ctx context.Context, repoID, userID int64) bool { var count int64 - db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count) + s.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count) return count > 0 } diff --git a/internal/database/repos_test.go b/internal/database/repos_test.go index 27b62fe3d..f9cba6ea9 100644 --- a/internal/database/repos_test.go +++ b/internal/database/repos_test.go @@ -85,13 +85,13 @@ func TestRepos(t *testing.T) { t.Parallel() ctx := context.Background() - db := &repos{ + db := &reposStore{ DB: newTestDB(t, "repos"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *repos) + test func(t *testing.T, ctx context.Context, db *reposStore) }{ {"Create", reposCreate}, {"GetByCollaboratorID", reposGetByCollaboratorID}, @@ -117,7 +117,7 @@ func TestRepos(t *testing.T) { } } -func reposCreate(t *testing.T, ctx context.Context, db *repos) { +func reposCreate(t *testing.T, ctx context.Context, db *reposStore) { t.Run("name not allowed", func(t *testing.T) { _, err := db.Create(ctx, 1, @@ -159,7 +159,7 @@ func reposCreate(t *testing.T, ctx context.Context, db *repos) { assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default. } -func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) { +func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *reposStore) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) @@ -185,7 +185,7 @@ func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) { }) } -func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *repos) { +func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *reposStore) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) @@ -213,7 +213,7 @@ func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, d assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID]) } -func reposGetByID(t *testing.T, ctx context.Context, db *repos) { +func reposGetByID(t *testing.T, ctx context.Context, db *reposStore) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) @@ -226,7 +226,7 @@ func reposGetByID(t *testing.T, ctx context.Context, db *repos) { assert.Equal(t, wantErr, err) } -func reposGetByName(t *testing.T, ctx context.Context, db *repos) { +func reposGetByName(t *testing.T, ctx context.Context, db *reposStore) { repo, err := db.Create(ctx, 1, CreateRepoOptions{ Name: "repo1", @@ -242,7 +242,7 @@ func reposGetByName(t *testing.T, ctx context.Context, db *repos) { assert.Equal(t, wantErr, err) } -func reposStar(t *testing.T, ctx context.Context, db *repos) { +func reposStar(t *testing.T, ctx context.Context, db *reposStore) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) usersStore := NewUsersStore(db.DB) @@ -261,7 +261,7 @@ func reposStar(t *testing.T, ctx context.Context, db *repos) { assert.Equal(t, 1, alice.NumStars) } -func reposTouch(t *testing.T, ctx context.Context, db *repos) { +func reposTouch(t *testing.T, ctx context.Context, db *reposStore) { repo, err := db.Create(ctx, 1, CreateRepoOptions{ Name: "repo1", @@ -287,7 +287,7 @@ func reposTouch(t *testing.T, ctx context.Context, db *repos) { assert.False(t, got.IsBare) } -func reposListWatches(t *testing.T, ctx context.Context, db *repos) { +func reposListWatches(t *testing.T, ctx context.Context, db *reposStore) { err := db.Watch(ctx, 1, 1) require.NoError(t, err) err = db.Watch(ctx, 2, 1) @@ -308,7 +308,7 @@ func reposListWatches(t *testing.T, ctx context.Context, db *repos) { assert.Equal(t, want, got) } -func reposWatch(t *testing.T, ctx context.Context, db *repos) { +func reposWatch(t *testing.T, ctx context.Context, db *reposStore) { reposStore := NewReposStore(db.DB) repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) @@ -325,7 +325,7 @@ func reposWatch(t *testing.T, ctx context.Context, db *repos) { assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default. } -func reposHasForkedBy(t *testing.T, ctx context.Context, db *repos) { +func reposHasForkedBy(t *testing.T, ctx context.Context, db *reposStore) { has := db.HasForkedBy(ctx, 1, 2) assert.False(t, has) diff --git a/internal/database/two_factors.go b/internal/database/two_factors.go index 33f2b49ca..2f2941d8b 100644 --- a/internal/database/two_factors.go +++ b/internal/database/two_factors.go @@ -50,13 +50,13 @@ func (t *TwoFactor) AfterFind(_ *gorm.DB) error { return nil } -var _ TwoFactorsStore = (*twoFactors)(nil) +var _ TwoFactorsStore = (*twoFactorsStore)(nil) -type twoFactors struct { +type twoFactorsStore struct { *gorm.DB } -func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret string) error { +func (s *twoFactorsStore) Create(ctx context.Context, userID int64, key, secret string) error { encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret)) if err != nil { return errors.Wrap(err, "encrypt secret") @@ -71,7 +71,7 @@ func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret stri return errors.Wrap(err, "generate recovery codes") } - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := tx.Create(tf).Error if err != nil { return err @@ -100,9 +100,9 @@ func (ErrTwoFactorNotFound) NotFound() bool { return true } -func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) { +func (s *twoFactorsStore) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) { tf := new(TwoFactor) - err := db.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error + err := s.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}} @@ -112,9 +112,9 @@ func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor return tf, nil } -func (db *twoFactors) IsEnabled(ctx context.Context, userID int64) bool { +func (s *twoFactorsStore) IsEnabled(ctx context.Context, userID int64) bool { var count int64 - err := db.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error + err := s.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error if err != nil { log.Error("Failed to count two factors [user_id: %d]: %v", userID, err) } diff --git a/internal/database/two_factors_test.go b/internal/database/two_factors_test.go index f4f30e022..af18055d6 100644 --- a/internal/database/two_factors_test.go +++ b/internal/database/two_factors_test.go @@ -67,13 +67,13 @@ func TestTwoFactors(t *testing.T) { t.Parallel() ctx := context.Background() - db := &twoFactors{ - DB: newTestDB(t, "twoFactors"), + db := &twoFactorsStore{ + DB: newTestDB(t, "twoFactorsStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *twoFactors) + test func(t *testing.T, ctx context.Context, db *twoFactorsStore) }{ {"Create", twoFactorsCreate}, {"GetByUserID", twoFactorsGetByUserID}, @@ -92,7 +92,7 @@ func TestTwoFactors(t *testing.T) { } } -func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) { +func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactorsStore) { // Create a 2FA token err := db.Create(ctx, 1, "secure-key", "secure-secret") require.NoError(t, err) @@ -109,7 +109,7 @@ func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) { assert.Equal(t, int64(10), count) } -func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) { +func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactorsStore) { // Create a 2FA token for user 1 err := db.Create(ctx, 1, "secure-key", "secure-secret") require.NoError(t, err) @@ -124,7 +124,7 @@ func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) { assert.Equal(t, wantErr, err) } -func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactors) { +func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactorsStore) { // Create a 2FA token for user 1 err := db.Create(ctx, 1, "secure-key", "secure-secret") require.NoError(t, err) diff --git a/internal/database/users.go b/internal/database/users.go index 018a94a4e..c8dba6e1b 100644 --- a/internal/database/users.go +++ b/internal/database/users.go @@ -146,16 +146,16 @@ type UsersStore interface { var Users UsersStore -var _ UsersStore = (*users)(nil) +var _ UsersStore = (*usersStore)(nil) -type users struct { +type usersStore struct { *gorm.DB } // NewUsersStore returns a persistent interface for users with given database // connection. func NewUsersStore(db *gorm.DB) UsersStore { - return &users{DB: db} + return &usersStore{DB: db} } type ErrLoginSourceMismatch struct { @@ -173,10 +173,10 @@ func (err ErrLoginSourceMismatch) Error() string { return fmt.Sprintf("login source mismatch: %v", err.args) } -func (db *users) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) { +func (s *usersStore) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) { login = strings.ToLower(login) - query := db.WithContext(ctx) + query := s.WithContext(ctx) if strings.Contains(login, "@") { query = query.Where("email = ?", login) } else { @@ -244,7 +244,7 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login return nil, fmt.Errorf("invalid pattern for attribute 'username' [%s]: must be valid alpha or numeric or dash(-_) or dot characters", extAccount.Name) } - return db.Create(ctx, extAccount.Name, extAccount.Email, + return s.Create(ctx, extAccount.Name, extAccount.Email, CreateUserOptions{ FullName: extAccount.FullName, LoginSource: authSourceID, @@ -257,13 +257,13 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login ) } -func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername string) error { +func (s *usersStore) ChangeUsername(ctx context.Context, userID int64, newUsername string) error { err := isUsernameAllowed(newUsername) if err != nil { return err } - if db.IsUsernameUsed(ctx, newUsername, userID) { + if s.IsUsernameUsed(ctx, newUsername, userID) { return ErrUserAlreadyExist{ args: errutil.Args{ "name": newUsername, @@ -271,12 +271,12 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s } } - user, err := db.GetByID(ctx, userID) + user, err := s.GetByID(ctx, userID) if err != nil { return errors.Wrap(err, "get user") } - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := tx.Model(&User{}). Where("id = ?", user.ID). Updates(map[string]any{ @@ -338,9 +338,9 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s }) } -func (db *users) Count(ctx context.Context) int64 { +func (s *usersStore) Count(ctx context.Context) int64 { var count int64 - db.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count) + s.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count) return count } @@ -393,13 +393,13 @@ func (err ErrEmailAlreadyUsed) Error() string { return fmt.Sprintf("email has been used: %v", err.args) } -func (db *users) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) { +func (s *usersStore) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) { err := isUsernameAllowed(username) if err != nil { return nil, err } - if db.IsUsernameUsed(ctx, username, 0) { + if s.IsUsernameUsed(ctx, username, 0) { return nil, ErrUserAlreadyExist{ args: errutil.Args{ "name": username, @@ -408,7 +408,7 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create } email = strings.ToLower(strings.TrimSpace(email)) - _, err = db.GetByEmail(ctx, email) + _, err = s.GetByEmail(ctx, email) if err == nil { return nil, ErrEmailAlreadyUsed{ args: errutil.Args{ @@ -446,17 +446,17 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create } user.Password = userutil.EncodePassword(user.Password, user.Salt) - return user, db.WithContext(ctx).Create(user).Error + return user, s.WithContext(ctx).Create(user).Error } -func (db *users) DeleteCustomAvatar(ctx context.Context, userID int64) error { +func (s *usersStore) DeleteCustomAvatar(ctx context.Context, userID int64) error { _ = os.Remove(userutil.CustomAvatarPath(userID)) - return db.WithContext(ctx). + return s.WithContext(ctx). Model(&User{}). Where("id = ?", userID). Updates(map[string]any{ "use_custom_avatar": false, - "updated_unix": db.NowFunc().Unix(), + "updated_unix": s.NowFunc().Unix(), }). Error } @@ -491,8 +491,8 @@ func (err ErrUserHasOrgs) Error() string { return fmt.Sprintf("user still has organization membership: %v", err.args) } -func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error { - user, err := db.GetByID(ctx, userID) +func (s *usersStore) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error { + user, err := s.GetByID(ctx, userID) if err != nil { if IsErrUserNotExist(err) { return nil @@ -503,14 +503,14 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor // Double check the user is not a direct owner of any repository and not a // member of any organization. var count int64 - err = db.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error + err = s.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error if err != nil { return errors.Wrap(err, "count repositories") } else if count > 0 { return ErrUserOwnRepos{args: errutil.Args{"userID": userID}} } - err = db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error + err = s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error if err != nil { return errors.Wrap(err, "count organization membership") } else if count > 0 { @@ -518,7 +518,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor } needsRewriteAuthorizedKeys := false - err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err = s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { /* Equivalent SQL for PostgreSQL: @@ -645,7 +645,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor _ = os.Remove(userutil.CustomAvatarPath(userID)) if needsRewriteAuthorizedKeys { - err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys() + err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys() if err != nil { return errors.Wrap(err, `rewrite "authorized_keys" file`) } @@ -655,15 +655,15 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor // NOTE: We do not take context.Context here because this operation in practice // could much longer than the general request timeout (e.g. one minute). -func (db *users) DeleteInactivated() error { +func (s *usersStore) DeleteInactivated() error { var userIDs []int64 - err := db.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error + err := s.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error if err != nil { return errors.Wrap(err, "get inactivated user IDs") } for _, userID := range userIDs { - err = db.DeleteByID(context.Background(), userID, true) + err = s.DeleteByID(context.Background(), userID, true) if err != nil { // Skip users that may had set to inactivated by admins. if IsErrUserOwnRepos(err) || IsErrUserHasOrgs(err) { @@ -672,14 +672,14 @@ func (db *users) DeleteInactivated() error { return errors.Wrapf(err, "delete user with ID %d", userID) } } - err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys() + err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys() if err != nil { return errors.Wrap(err, `rewrite "authorized_keys" file`) } return nil } -func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error { +func (*usersStore) recountFollows(tx *gorm.DB, userID, followID int64) error { /* Equivalent SQL for PostgreSQL: @@ -722,12 +722,12 @@ func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error { return nil } -func (db *users) Follow(ctx context.Context, userID, followID int64) error { +func (s *usersStore) Follow(ctx context.Context, userID, followID int64) error { if userID == followID { return nil } - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { f := &Follow{ UserID: userID, FollowID: followID, @@ -739,26 +739,26 @@ func (db *users) Follow(ctx context.Context, userID, followID int64) error { return nil // Relation already exists } - return db.recountFollows(tx, userID, followID) + return s.recountFollows(tx, userID, followID) }) } -func (db *users) Unfollow(ctx context.Context, userID, followID int64) error { +func (s *usersStore) Unfollow(ctx context.Context, userID, followID int64) error { if userID == followID { return nil } - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error if err != nil { return errors.Wrap(err, "delete") } - return db.recountFollows(tx, userID, followID) + return s.recountFollows(tx, userID, followID) }) } -func (db *users) IsFollowing(ctx context.Context, userID, followID int64) bool { - return db.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil +func (s *usersStore) IsFollowing(ctx context.Context, userID, followID int64) bool { + return s.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil } var _ errutil.NotFound = (*ErrUserNotExist)(nil) @@ -782,7 +782,7 @@ func (ErrUserNotExist) NotFound() bool { return true } -func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) { +func (s *usersStore) GetByEmail(ctx context.Context, email string) (*User, error) { if email == "" { return nil, ErrUserNotExist{args: errutil.Args{"email": email}} } @@ -801,10 +801,10 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) { ) */ user := new(User) - err := db.WithContext(ctx). + err := s.WithContext(ctx). Joins(dbutil.Quote("LEFT JOIN email_address ON email_address.uid = %s.id", "user"), true). Where(dbutil.Quote("%s.type = ?", "user"), UserTypeIndividual). - Where(db. + Where(s. Where(dbutil.Quote("%[1]s.email = ? AND %[1]s.is_active = ?", "user"), email, true). Or("email_address.email = ? AND email_address.is_activated = ?", email, true), ). @@ -819,9 +819,9 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) { return user, nil } -func (db *users) GetByID(ctx context.Context, id int64) (*User, error) { +func (s *usersStore) GetByID(ctx context.Context, id int64) (*User, error) { user := new(User) - err := db.WithContext(ctx).Where("id = ?", id).First(user).Error + err := s.WithContext(ctx).Where("id = ?", id).First(user).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"userID": id}} @@ -831,9 +831,9 @@ func (db *users) GetByID(ctx context.Context, id int64) (*User, error) { return user, nil } -func (db *users) GetByUsername(ctx context.Context, username string) (*User, error) { +func (s *usersStore) GetByUsername(ctx context.Context, username string) (*User, error) { user := new(User) - err := db.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error + err := s.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"name": username}} @@ -843,9 +843,9 @@ func (db *users) GetByUsername(ctx context.Context, username string) (*User, err return user, nil } -func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) { +func (s *usersStore) GetByKeyID(ctx context.Context, keyID int64) (*User, error) { user := new(User) - err := db.WithContext(ctx). + err := s.WithContext(ctx). Joins(dbutil.Quote("JOIN public_key ON public_key.owner_id = %s.id", "user")). Where("public_key.id = ?", keyID). First(user). @@ -859,29 +859,29 @@ func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) { return user, nil } -func (db *users) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) { +func (s *usersStore) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) { emails := make([]string, 0, len(usernames)) - return emails, db.WithContext(ctx). + return emails, s.WithContext(ctx). Model(&User{}). Select("email"). Where("lower_name IN (?) AND is_active = ?", usernames, true). Find(&emails).Error } -func (db *users) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool { +func (s *usersStore) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool { if username == "" { return false } - return db.WithContext(ctx). + return s.WithContext(ctx). Select("id"). Where("lower_name = ? AND id != ?", strings.ToLower(username), excludeUserId). First(&User{}). Error != gorm.ErrRecordNotFound } -func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error) { +func (s *usersStore) List(ctx context.Context, page, pageSize int) ([]*User, error) { users := make([]*User, 0, pageSize) - return users, db.WithContext(ctx). + return users, s.WithContext(ctx). Where("type = ?", UserTypeIndividual). Limit(pageSize).Offset((page - 1) * pageSize). Order("id ASC"). @@ -889,7 +889,7 @@ func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error) Error } -func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { +func (s *usersStore) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { /* Equivalent SQL for PostgreSQL: @@ -900,7 +900,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize LIMIT @limit OFFSET @offset */ users := make([]*User, 0, pageSize) - return users, db.WithContext(ctx). + return users, s.WithContext(ctx). Joins(dbutil.Quote("LEFT JOIN follow ON follow.user_id = %s.id", "user")). Where("follow.follow_id = ?", userID). Limit(pageSize).Offset((page - 1) * pageSize). @@ -909,7 +909,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize Error } -func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { +func (s *usersStore) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { /* Equivalent SQL for PostgreSQL: @@ -920,7 +920,7 @@ func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSiz LIMIT @limit OFFSET @offset */ users := make([]*User, 0, pageSize) - return users, db.WithContext(ctx). + return users, s.WithContext(ctx). Joins(dbutil.Quote("LEFT JOIN follow ON follow.follow_id = %s.id", "user")). Where("follow.user_id = ?", userID). Limit(pageSize).Offset((page - 1) * pageSize). @@ -948,8 +948,8 @@ func searchUserByName(ctx context.Context, db *gorm.DB, userType UserType, keywo return users, count, tx.Order(orderBy).Limit(pageSize).Offset((page - 1) * pageSize).Find(&users).Error } -func (db *users) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) { - return searchUserByName(ctx, db.DB, UserTypeIndividual, keyword, page, pageSize, orderBy) +func (s *usersStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) { + return searchUserByName(ctx, s.DB, UserTypeIndividual, keyword, page, pageSize, orderBy) } type UpdateUserOptions struct { @@ -979,9 +979,9 @@ type UpdateUserOptions struct { AvatarEmail *string } -func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error { +func (s *usersStore) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error { updates := map[string]any{ - "updated_unix": db.NowFunc().Unix(), + "updated_unix": s.NowFunc().Unix(), } if opts.LoginSource != nil { @@ -1012,7 +1012,7 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption updates["full_name"] = strutil.Truncate(*opts.FullName, 255) } if opts.Email != nil { - _, err := db.GetByEmail(ctx, *opts.Email) + _, err := s.GetByEmail(ctx, *opts.Email) if err == nil { return ErrEmailAlreadyUsed{args: errutil.Args{"email": *opts.Email}} } else if !IsErrUserNotExist(err) { @@ -1063,28 +1063,28 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption updates["avatar_email"] = strutil.Truncate(*opts.AvatarEmail, 255) } - return db.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error + return s.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error } -func (db *users) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error { +func (s *usersStore) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error { err := userutil.SaveAvatar(userID, avatar) if err != nil { return errors.Wrap(err, "save avatar") } - return db.WithContext(ctx). + return s.WithContext(ctx). Model(&User{}). Where("id = ?", userID). Updates(map[string]any{ "use_custom_avatar": true, - "updated_unix": db.NowFunc().Unix(), + "updated_unix": s.NowFunc().Unix(), }). Error } -func (db *users) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error { +func (s *usersStore) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error { email = strings.ToLower(strings.TrimSpace(email)) - _, err := db.GetByEmail(ctx, email) + _, err := s.GetByEmail(ctx, email) if err == nil { return ErrEmailAlreadyUsed{ args: errutil.Args{ @@ -1095,7 +1095,7 @@ func (db *users) AddEmail(ctx context.Context, userID int64, email string, isAct return errors.Wrap(err, "check user by email") } - return db.WithContext(ctx).Create( + return s.WithContext(ctx).Create( &EmailAddress{ UserID: userID, Email: email, @@ -1125,8 +1125,8 @@ func (ErrEmailNotExist) NotFound() bool { return true } -func (db *users) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) { - tx := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email) +func (s *usersStore) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) { + tx := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email) if needsActivated { tx = tx.Where("is_activated = ?", true) } @@ -1146,14 +1146,14 @@ func (db *users) GetEmail(ctx context.Context, userID int64, email string, needs return emailAddress, nil } -func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) { - user, err := db.GetByID(ctx, userID) +func (s *usersStore) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) { + user, err := s.GetByID(ctx, userID) if err != nil { return nil, errors.Wrap(err, "get user") } var emails []*EmailAddress - err = db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error + err = s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error if err != nil { return nil, errors.Wrap(err, "list emails") } @@ -1179,9 +1179,9 @@ func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, return emails, nil } -func (db *users) MarkEmailActivated(ctx context.Context, userID int64, email string) error { - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - err := db.WithContext(ctx). +func (s *usersStore) MarkEmailActivated(ctx context.Context, userID int64, email string) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := s.WithContext(ctx). Model(&EmailAddress{}). Where("uid = ? AND email = ?", userID, email). Update("is_activated", true). @@ -1209,9 +1209,9 @@ func (err ErrEmailNotVerified) Error() string { return fmt.Sprintf("email has not been verified: %v", err.args) } -func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email string) error { +func (s *usersStore) MarkEmailPrimary(ctx context.Context, userID int64, email string) error { var emailAddress EmailAddress - err := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error + err := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error if err != nil { if err == gorm.ErrRecordNotFound { return ErrEmailNotExist{args: errutil.Args{"email": email}} @@ -1223,12 +1223,12 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin return ErrEmailNotVerified{args: errutil.Args{"email": email}} } - user, err := db.GetByID(ctx, userID) + user, err := s.GetByID(ctx, userID) if err != nil { return errors.Wrap(err, "get user") } - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Make sure the former primary email doesn't disappear. err = tx.FirstOrCreate( &EmailAddress{ @@ -1255,8 +1255,8 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin }) } -func (db *users) DeleteEmail(ctx context.Context, userID int64, email string) error { - return db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error +func (s *usersStore) DeleteEmail(ctx context.Context, userID int64, email string) error { + return s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error } // UserType indicates the type of the user account. diff --git a/internal/database/users_test.go b/internal/database/users_test.go index 6b2c3ac13..1225c482d 100644 --- a/internal/database/users_test.go +++ b/internal/database/users_test.go @@ -84,13 +84,13 @@ func TestUsers(t *testing.T) { t.Parallel() ctx := context.Background() - db := &users{ - DB: newTestDB(t, "users"), + db := &usersStore{ + DB: newTestDB(t, "usersStore"), } for _, tc := range []struct { name string - test func(t *testing.T, ctx context.Context, db *users) + test func(t *testing.T, ctx context.Context, db *usersStore) }{ {"Authenticate", usersAuthenticate}, {"ChangeUsername", usersChangeUsername}, @@ -134,7 +134,7 @@ func TestUsers(t *testing.T) { } } -func usersAuthenticate(t *testing.T, ctx context.Context, db *users) { +func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) { password := "pa$$word" alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{ @@ -229,7 +229,7 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *users) { }) } -func usersChangeUsername(t *testing.T, ctx context.Context, db *users) { +func usersChangeUsername(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create( ctx, "alice", @@ -359,7 +359,7 @@ func usersChangeUsername(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, strings.ToUpper(newUsername), alice.Name) } -func usersCount(t *testing.T, ctx context.Context, db *users) { +func usersCount(t *testing.T, ctx context.Context, db *usersStore) { // Has no user initially got := db.Count(ctx) assert.Equal(t, int64(0), got) @@ -382,7 +382,7 @@ func usersCount(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, int64(1), got) } -func usersCreate(t *testing.T, ctx context.Context, db *users) { +func usersCreate(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create( ctx, "alice", @@ -430,7 +430,7 @@ func usersCreate(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339)) } -func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) { +func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -464,7 +464,7 @@ func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) { assert.False(t, alice.UseCustomAvatar) } -func usersDeleteByID(t *testing.T, ctx context.Context, db *users) { +func usersDeleteByID(t *testing.T, ctx context.Context, db *usersStore) { reposStore := NewReposStore(db.DB) t.Run("user still has repository ownership", func(t *testing.T) { @@ -674,7 +674,7 @@ func usersDeleteByID(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, wantErr, err) } -func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) { +func usersDeleteInactivated(t *testing.T, ctx context.Context, db *usersStore) { // User with repository ownership should be skipped alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -720,7 +720,7 @@ func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) { require.Len(t, users, 3) } -func usersGetByEmail(t *testing.T, ctx context.Context, db *users) { +func usersGetByEmail(t *testing.T, ctx context.Context, db *usersStore) { t.Run("empty email", func(t *testing.T) { _, err := db.GetByEmail(ctx, "") wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}} @@ -781,7 +781,7 @@ func usersGetByEmail(t *testing.T, ctx context.Context, db *users) { }) } -func usersGetByID(t *testing.T, ctx context.Context, db *users) { +func usersGetByID(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -794,7 +794,7 @@ func usersGetByID(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, wantErr, err) } -func usersGetByUsername(t *testing.T, ctx context.Context, db *users) { +func usersGetByUsername(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -807,7 +807,7 @@ func usersGetByUsername(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, wantErr, err) } -func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) { +func usersGetByKeyID(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -832,7 +832,7 @@ func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, wantErr, err) } -func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *users) { +func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true}) @@ -846,7 +846,7 @@ func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *us assert.Equal(t, want, got) } -func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) { +func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -896,7 +896,7 @@ func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) { } } -func usersList(t *testing.T, ctx context.Context, db *users) { +func usersList(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) @@ -929,7 +929,7 @@ func usersList(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, bob.ID, got[1].ID) } -func usersListFollowers(t *testing.T, ctx context.Context, db *users) { +func usersListFollowers(t *testing.T, ctx context.Context, db *usersStore) { john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -960,7 +960,7 @@ func usersListFollowers(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, alice.ID, got[0].ID) } -func usersListFollowings(t *testing.T, ctx context.Context, db *users) { +func usersListFollowings(t *testing.T, ctx context.Context, db *usersStore) { john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -991,7 +991,7 @@ func usersListFollowings(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, alice.ID, got[0].ID) } -func usersSearchByName(t *testing.T, ctx context.Context, db *users) { +func usersSearchByName(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"}) require.NoError(t, err) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"}) @@ -1029,7 +1029,7 @@ func usersSearchByName(t *testing.T, ctx context.Context, db *users) { }) } -func usersUpdate(t *testing.T, ctx context.Context, db *users) { +func usersUpdate(t *testing.T, ctx context.Context, db *usersStore) { const oldPassword = "Password" alice, err := db.Create( ctx, @@ -1142,7 +1142,7 @@ func usersUpdate(t *testing.T, ctx context.Context, db *users) { assertValues() } -func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *users) { +func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1180,7 +1180,7 @@ func TestIsUsernameAllowed(t *testing.T) { } } -func usersAddEmail(t *testing.T, ctx context.Context, db *users) { +func usersAddEmail(t *testing.T, ctx context.Context, db *usersStore) { t.Run("multiple users can add the same unverified email", func(t *testing.T) { alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1197,7 +1197,7 @@ func usersAddEmail(t *testing.T, ctx context.Context, db *users) { }) } -func usersGetEmail(t *testing.T, ctx context.Context, db *users) { +func usersGetEmail(t *testing.T, ctx context.Context, db *usersStore) { const testUserID = 1 const testEmail = "alice@example.com" _, err := db.GetEmail(ctx, testUserID, testEmail, false) @@ -1229,7 +1229,7 @@ func usersGetEmail(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, testEmail, got.Email) } -func usersListEmails(t *testing.T, ctx context.Context, db *users) { +func usersListEmails(t *testing.T, ctx context.Context, db *usersStore) { t.Run("list emails with primary email", func(t *testing.T) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1265,7 +1265,7 @@ func usersListEmails(t *testing.T, ctx context.Context, db *users) { }) } -func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) { +func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1283,7 +1283,7 @@ func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) { assert.NotEqual(t, alice.Rands, gotAlice.Rands) } -func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) { +func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false) @@ -1309,7 +1309,7 @@ func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) { assert.False(t, gotEmail.IsActivated) } -func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) { +func usersDeleteEmail(t *testing.T, ctx context.Context, db *usersStore) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1325,7 +1325,7 @@ func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) { require.Equal(t, want, got) } -func usersFollow(t *testing.T, ctx context.Context, db *users) { +func usersFollow(t *testing.T, ctx context.Context, db *usersStore) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1348,7 +1348,7 @@ func usersFollow(t *testing.T, ctx context.Context, db *users) { assert.Equal(t, 1, bob.NumFollowers) } -func usersIsFollowing(t *testing.T, ctx context.Context, db *users) { +func usersIsFollowing(t *testing.T, ctx context.Context, db *usersStore) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1369,7 +1369,7 @@ func usersIsFollowing(t *testing.T, ctx context.Context, db *users) { assert.False(t, got) } -func usersUnfollow(t *testing.T, ctx context.Context, db *users) { +func usersUnfollow(t *testing.T, ctx context.Context, db *usersStore) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err)