internal/database: consistently use Store and s as receiver (#7669)

This commit is contained in:
Joe Chen 2024-02-19 20:00:13 -05:00 committed by GitHub
parent dfe27ad556
commit 917c14f2ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 371 additions and 371 deletions

View File

@ -74,9 +74,9 @@ func (t *AccessToken) AfterFind(tx *gorm.DB) error {
return nil return nil
} }
var _ AccessTokensStore = (*accessTokens)(nil) var _ AccessTokensStore = (*accessTokensStore)(nil)
type accessTokens struct { type accessTokensStore struct {
*gorm.DB *gorm.DB
} }
@ -93,8 +93,8 @@ func (err ErrAccessTokenAlreadyExist) Error() string {
return fmt.Sprintf("access token already exists: %v", err.args) return fmt.Sprintf("access token already exists: %v", err.args)
} }
func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) { func (s *accessTokensStore) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) {
err := db.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error err := s.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
if err == nil { if err == nil {
return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}} return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}}
} else if err != gorm.ErrRecordNotFound { } else if err != gorm.ErrRecordNotFound {
@ -110,7 +110,7 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (
Sha1: sha256[:40], // To pass the column unique constraint, keep the length of SHA1. Sha1: sha256[:40], // To pass the column unique constraint, keep the length of SHA1.
SHA256: sha256, SHA256: sha256,
} }
if err = db.WithContext(ctx).Create(accessToken).Error; err != nil { if err = s.WithContext(ctx).Create(accessToken).Error; err != nil {
return nil, err return nil, err
} }
@ -119,8 +119,8 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (
return accessToken, nil return accessToken, nil
} }
func (db *accessTokens) DeleteByID(ctx context.Context, userID, id int64) error { func (s *accessTokensStore) DeleteByID(ctx context.Context, userID, id int64) error {
return db.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error return s.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error
} }
var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil) var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil)
@ -144,7 +144,7 @@ func (ErrAccessTokenNotExist) NotFound() bool {
return true return true
} }
func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) { func (s *accessTokensStore) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) {
// No need to waste a query for an empty SHA1. // No need to waste a query for an empty SHA1.
if sha1 == "" { if sha1 == "" {
return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}} return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}}
@ -152,7 +152,7 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke
sha256 := cryptoutil.SHA256(sha1) sha256 := cryptoutil.SHA256(sha1)
token := new(AccessToken) token := new(AccessToken)
err := db.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error err := s.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}} return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}}
@ -162,15 +162,15 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke
return token, nil return token, nil
} }
func (db *accessTokens) List(ctx context.Context, userID int64) ([]*AccessToken, error) { func (s *accessTokensStore) List(ctx context.Context, userID int64) ([]*AccessToken, error) {
var tokens []*AccessToken var tokens []*AccessToken
return tokens, db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error return tokens, s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error
} }
func (db *accessTokens) Touch(ctx context.Context, id int64) error { func (s *accessTokensStore) Touch(ctx context.Context, id int64) error {
return db.WithContext(ctx). return s.WithContext(ctx).
Model(new(AccessToken)). Model(new(AccessToken)).
Where("id = ?", id). Where("id = ?", id).
UpdateColumn("updated_unix", db.NowFunc().Unix()). UpdateColumn("updated_unix", s.NowFunc().Unix()).
Error Error
} }

View File

@ -98,13 +98,13 @@ func TestAccessTokens(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &accessTokens{ db := &accessTokensStore{
DB: newTestDB(t, "accessTokens"), DB: newTestDB(t, "accessTokensStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *accessTokens) test func(t *testing.T, ctx context.Context, db *accessTokensStore)
}{ }{
{"Create", accessTokensCreate}, {"Create", accessTokensCreate},
{"DeleteByID", accessTokensDeleteByID}, {"DeleteByID", accessTokensDeleteByID},
@ -125,7 +125,7 @@ func TestAccessTokens(t *testing.T) {
} }
} }
func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) { func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create first access token with name "Test" // Create first access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)
@ -150,7 +150,7 @@ func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) { func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create an access token with name "Test" // Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)
@ -177,7 +177,7 @@ func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens)
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) { func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create an access token with name "Test" // Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)
@ -196,7 +196,7 @@ func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) { func accessTokensList(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create two access tokens for user 1 // Create two access tokens for user 1
_, err := db.Create(ctx, 1, "user1_1") _, err := db.Create(ctx, 1, "user1_1")
require.NoError(t, err) require.NoError(t, err)
@ -219,7 +219,7 @@ func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) {
assert.Equal(t, "user1_2", tokens[1].Name) assert.Equal(t, "user1_2", tokens[1].Name)
} }
func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokens) { func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokensStore) {
// Create an access token with name "Test" // Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)

View File

@ -70,19 +70,19 @@ type ActionsStore interface {
var Actions ActionsStore var Actions ActionsStore
var _ ActionsStore = (*actions)(nil) var _ ActionsStore = (*actionsStore)(nil)
type actions struct { type actionsStore struct {
*gorm.DB *gorm.DB
} }
// NewActionsStore returns a persistent interface for actions with given // NewActionsStore returns a persistent interface for actions with given
// database connection. // database connection.
func NewActionsStore(db *gorm.DB) ActionsStore { func NewActionsStore(db *gorm.DB) ActionsStore {
return &actions{DB: db} return &actionsStore{DB: db}
} }
func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB { func (s *actionsStore) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -102,18 +102,18 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after
ORDER BY id DESC ORDER BY id DESC
LIMIT @limit LIMIT @limit
*/ */
return db.WithContext(ctx). return s.WithContext(ctx).
Where("user_id = ?", orgID). Where("user_id = ?", orgID).
Where(db. Where(s.
// Not apply when afterID is not given // Not apply when afterID is not given
Where("?", afterID <= 0). Where("?", afterID <= 0).
Or("id < ?", afterID), Or("id < ?", afterID),
). ).
Where("repo_id IN (?)", db. Where("repo_id IN (?)", s.
Select("repository.id"). Select("repository.id").
Table("repository"). Table("repository").
Joins("JOIN team_repo ON repository.id = team_repo.repo_id"). Joins("JOIN team_repo ON repository.id = team_repo.repo_id").
Where("team_repo.team_id IN (?)", db. Where("team_repo.team_id IN (?)", s.
Select("team_id"). Select("team_id").
Table("team_user"). Table("team_user").
Where("team_user.org_id = ? AND uid = ?", orgID, actorID), Where("team_user.org_id = ? AND uid = ?", orgID, actorID),
@ -124,12 +124,12 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after
Order("id DESC") Order("id DESC")
} }
func (db *actions) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) { func (s *actionsStore) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) {
actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum) actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum)
return actions, db.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error return actions, s.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error
} }
func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB { func (s *actionsStore) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -141,14 +141,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6
ORDER BY id DESC ORDER BY id DESC
LIMIT @limit LIMIT @limit
*/ */
return db.WithContext(ctx). return s.WithContext(ctx).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Where(db. Where(s.
// Not apply when afterID is not given // Not apply when afterID is not given
Where("?", afterID <= 0). Where("?", afterID <= 0).
Or("id < ?", afterID), Or("id < ?", afterID),
). ).
Where(db. Where(s.
// Not apply when in not profile page or the user is viewing own profile // Not apply when in not profile page or the user is viewing own profile
Where("?", !isProfile || actorID == userID). Where("?", !isProfile || actorID == userID).
Or("is_private = ? AND act_user_id = ?", false, userID), Or("is_private = ? AND act_user_id = ?", false, userID),
@ -157,14 +157,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6
Order("id DESC") Order("id DESC")
} }
func (db *actions) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) { func (s *actionsStore) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) {
actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum) actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum)
return actions, db.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error return actions, s.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error
} }
// notifyWatchers creates rows in action table for watchers who are able to see the action. // notifyWatchers creates rows in action table for watchers who are able to see the action.
func (db *actions) notifyWatchers(ctx context.Context, act *Action) error { func (s *actionsStore) notifyWatchers(ctx context.Context, act *Action) error {
watches, err := NewReposStore(db.DB).ListWatches(ctx, act.RepoID) watches, err := NewReposStore(s.DB).ListWatches(ctx, act.RepoID)
if err != nil { if err != nil {
return errors.Wrap(err, "list watches") return errors.Wrap(err, "list watches")
} }
@ -187,16 +187,16 @@ func (db *actions) notifyWatchers(ctx context.Context, act *Action) error {
actions = append(actions, clone(watch.UserID)) actions = append(actions, clone(watch.UserID))
} }
return db.Create(actions).Error return s.Create(actions).Error
} }
func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error { func (s *actionsStore) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error {
opType := ActionCreateRepo opType := ActionCreateRepo
if repo.IsFork { if repo.IsFork {
opType = ActionForkRepo opType = ActionForkRepo
} }
return db.notifyWatchers(ctx, return s.notifyWatchers(ctx,
&Action{ &Action{
ActUserID: doer.ID, ActUserID: doer.ID,
ActUserName: doer.Name, ActUserName: doer.Name,
@ -209,8 +209,8 @@ func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Reposit
) )
} }
func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error { func (s *actionsStore) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error {
return db.notifyWatchers(ctx, return s.notifyWatchers(ctx,
&Action{ &Action{
ActUserID: doer.ID, ActUserID: doer.ID,
ActUserName: doer.Name, ActUserName: doer.Name,
@ -224,8 +224,8 @@ func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoNam
) )
} }
func (db *actions) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error { func (s *actionsStore) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error {
return db.notifyWatchers(ctx, return s.notifyWatchers(ctx,
&Action{ &Action{
ActUserID: owner.ID, ActUserID: owner.ID,
ActUserName: owner.Name, ActUserName: owner.Name,
@ -249,13 +249,13 @@ type MirrorSyncPushOptions struct {
Commits *PushCommits Commits *PushCommits
} }
func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error { func (s *actionsStore) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error {
if conf.UI.FeedMaxCommitNum > 0 && len(opts.Commits.Commits) > conf.UI.FeedMaxCommitNum { if conf.UI.FeedMaxCommitNum > 0 && len(opts.Commits.Commits) > conf.UI.FeedMaxCommitNum {
opts.Commits.Commits = opts.Commits.Commits[:conf.UI.FeedMaxCommitNum] opts.Commits.Commits = opts.Commits.Commits[:conf.UI.FeedMaxCommitNum]
} }
apiCommits, err := opts.Commits.APIFormat(ctx, apiCommits, err := opts.Commits.APIFormat(ctx,
NewUsersStore(db.DB), NewUsersStore(s.DB),
repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name), repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name),
repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name), repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name),
) )
@ -288,19 +288,19 @@ func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOption
return errors.Wrap(err, "marshal JSON") return errors.Wrap(err, "marshal JSON")
} }
return db.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data) return s.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data)
} }
func (db *actions) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error { func (s *actionsStore) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error {
return db.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil) return s.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil)
} }
func (db *actions) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error { func (s *actionsStore) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error {
return db.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil) return s.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil)
} }
func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error { func (s *actionsStore) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error {
return db.notifyWatchers(ctx, return s.notifyWatchers(ctx,
&Action{ &Action{
ActUserID: doer.ID, ActUserID: doer.ID,
ActUserName: doer.Name, ActUserName: doer.Name,
@ -314,8 +314,8 @@ func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo
) )
} }
func (db *actions) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error { func (s *actionsStore) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error {
return db.notifyWatchers(ctx, return s.notifyWatchers(ctx,
&Action{ &Action{
ActUserID: doer.ID, ActUserID: doer.ID,
ActUserName: doer.Name, ActUserName: doer.Name,
@ -487,13 +487,13 @@ type CommitRepoOptions struct {
Commits *PushCommits Commits *PushCommits
} }
func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error { func (s *actionsStore) CommitRepo(ctx context.Context, opts CommitRepoOptions) error {
err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID) err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "touch repository") return errors.Wrap(err, "touch repository")
} }
pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName) pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName)
if err != nil { if err != nil {
return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName) return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName)
} }
@ -536,7 +536,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
} }
action.OpType = ActionDeleteBranch action.OpType = ActionDeleteBranch
err = db.notifyWatchers(ctx, action) err = s.notifyWatchers(ctx, action)
if err != nil { if err != nil {
return errors.Wrap(err, "notify watchers") return errors.Wrap(err, "notify watchers")
} }
@ -580,7 +580,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
} }
action.OpType = ActionCreateBranch action.OpType = ActionCreateBranch
err = db.notifyWatchers(ctx, action) err = s.notifyWatchers(ctx, action)
if err != nil { if err != nil {
return errors.Wrap(err, "notify watchers") return errors.Wrap(err, "notify watchers")
} }
@ -589,7 +589,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
} }
commits, err := opts.Commits.APIFormat(ctx, commits, err := opts.Commits.APIFormat(ctx,
NewUsersStore(db.DB), NewUsersStore(s.DB),
repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name), repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name),
repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name), repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name),
) )
@ -616,7 +616,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
} }
action.OpType = ActionCommitRepo action.OpType = ActionCommitRepo
err = db.notifyWatchers(ctx, action) err = s.notifyWatchers(ctx, action)
if err != nil { if err != nil {
return errors.Wrap(err, "notify watchers") return errors.Wrap(err, "notify watchers")
} }
@ -631,13 +631,13 @@ type PushTagOptions struct {
NewCommitID string NewCommitID string
} }
func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error { func (s *actionsStore) PushTag(ctx context.Context, opts PushTagOptions) error {
err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID) err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "touch repository") return errors.Wrap(err, "touch repository")
} }
pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName) pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName)
if err != nil { if err != nil {
return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName) return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName)
} }
@ -672,7 +672,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
} }
action.OpType = ActionDeleteTag action.OpType = ActionDeleteTag
err = db.notifyWatchers(ctx, action) err = s.notifyWatchers(ctx, action)
if err != nil { if err != nil {
return errors.Wrap(err, "notify watchers") return errors.Wrap(err, "notify watchers")
} }
@ -696,7 +696,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
} }
action.OpType = ActionPushTag action.OpType = ActionPushTag
err = db.notifyWatchers(ctx, action) err = s.notifyWatchers(ctx, action)
if err != nil { if err != nil {
return errors.Wrap(err, "notify watchers") return errors.Wrap(err, "notify watchers")
} }

View File

@ -99,13 +99,13 @@ func TestActions(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Parallel() t.Parallel()
db := &actions{ db := &actionsStore{
DB: newTestDB(t, "actions"), DB: newTestDB(t, "actionsStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *actions) test func(t *testing.T, ctx context.Context, db *actionsStore)
}{ }{
{"CommitRepo", actionsCommitRepo}, {"CommitRepo", actionsCommitRepo},
{"ListByOrganization", actionsListByOrganization}, {"ListByOrganization", actionsListByOrganization},
@ -132,7 +132,7 @@ func TestActions(t *testing.T) {
} }
} }
func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) { func actionsCommitRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -324,7 +324,7 @@ func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) {
}) })
} }
func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) { func actionsListByOrganization(t *testing.T, ctx context.Context, db *actionsStore) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL") t.Skip("Skipping testing with not using PostgreSQL")
return return
@ -363,14 +363,14 @@ func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB { got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return NewActionsStore(tx).(*actions).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action)) return NewActionsStore(tx).(*actionsStore).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action))
}) })
assert.Equal(t, test.want, got) assert.Equal(t, test.want, got)
}) })
} }
} }
func actionsListByUser(t *testing.T, ctx context.Context, db *actions) { func actionsListByUser(t *testing.T, ctx context.Context, db *actionsStore) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL") t.Skip("Skipping testing with not using PostgreSQL")
return return
@ -428,14 +428,14 @@ func actionsListByUser(t *testing.T, ctx context.Context, db *actions) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB { got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return NewActionsStore(tx).(*actions).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action)) return NewActionsStore(tx).(*actionsStore).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action))
}) })
assert.Equal(t, test.want, got) assert.Equal(t, test.want, got)
}) })
} }
} }
func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) { func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -480,7 +480,7 @@ func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) { func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -521,7 +521,7 @@ func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) { func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -562,7 +562,7 @@ func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) { func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -627,7 +627,7 @@ func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) { func actionsNewRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -702,7 +702,7 @@ func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) {
}) })
} }
func actionsPushTag(t *testing.T, ctx context.Context, db *actions) { func actionsPushTag(t *testing.T, ctx context.Context, db *actionsStore) {
// NOTE: We set a noop mock here to avoid data race with other tests that writes // NOTE: We set a noop mock here to avoid data race with other tests that writes
// to the mock server because this function holds a lock. // to the mock server because this function holds a lock.
conf.SetMockServer(t, conf.ServerOpts{}) conf.SetMockServer(t, conf.ServerOpts{})
@ -798,7 +798,7 @@ func actionsPushTag(t *testing.T, ctx context.Context, db *actions) {
}) })
} }
func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) { func actionsRenameRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -835,7 +835,7 @@ func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsTransferRepo(t *testing.T, ctx context.Context, db *actions) { func actionsTransferRepo(t *testing.T, ctx context.Context, db *actionsStore) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{})

View File

@ -123,15 +123,15 @@ func Init(w logger.Writer) (*gorm.DB, error) {
} }
// Initialize stores, sorted in alphabetical order. // Initialize stores, sorted in alphabetical order.
AccessTokens = &accessTokens{DB: db} AccessTokens = &accessTokensStore{DB: db}
Actions = NewActionsStore(db) Actions = NewActionsStore(db)
LoginSources = &loginSources{DB: db, files: sourceFiles} LoginSources = &loginSourcesStore{DB: db, files: sourceFiles}
LFS = &lfs{DB: db} LFS = &lfsStore{DB: db}
Notices = NewNoticesStore(db) Notices = NewNoticesStore(db)
Orgs = NewOrgsStore(db) Orgs = NewOrgsStore(db)
Perms = NewPermsStore(db) Perms = NewPermsStore(db)
Repos = NewReposStore(db) Repos = NewReposStore(db)
TwoFactors = &twoFactors{DB: db} TwoFactors = &twoFactorsStore{DB: db}
Users = NewUsersStore(db) Users = NewUsersStore(db)
return db, nil return db, nil

View File

@ -38,20 +38,20 @@ type LFSObject struct {
CreatedAt time.Time `gorm:"not null"` CreatedAt time.Time `gorm:"not null"`
} }
var _ LFSStore = (*lfs)(nil) var _ LFSStore = (*lfsStore)(nil)
type lfs struct { type lfsStore struct {
*gorm.DB *gorm.DB
} }
func (db *lfs) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error { func (s *lfsStore) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
object := &LFSObject{ object := &LFSObject{
RepoID: repoID, RepoID: repoID,
OID: oid, OID: oid,
Size: size, Size: size,
Storage: storage, Storage: storage,
} }
return db.WithContext(ctx).Create(object).Error return s.WithContext(ctx).Create(object).Error
} }
type ErrLFSObjectNotExist struct { type ErrLFSObjectNotExist struct {
@ -71,9 +71,9 @@ func (ErrLFSObjectNotExist) NotFound() bool {
return true return true
} }
func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) { func (s *lfsStore) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) {
object := new(LFSObject) object := new(LFSObject)
err := db.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error err := s.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}} return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}}
@ -83,13 +83,13 @@ func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID
return object, err return object, err
} }
func (db *lfs) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) { func (s *lfsStore) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
if len(oids) == 0 { if len(oids) == 0 {
return []*LFSObject{}, nil return []*LFSObject{}, nil
} }
objects := make([]*LFSObject, 0, len(oids)) objects := make([]*LFSObject, 0, len(oids))
err := db.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error err := s.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error
if err != nil && err != gorm.ErrRecordNotFound { if err != nil && err != gorm.ErrRecordNotFound {
return nil, err return nil, err
} }

View File

@ -23,13 +23,13 @@ func TestLFS(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &lfs{ db := &lfsStore{
DB: newTestDB(t, "lfs"), DB: newTestDB(t, "lfsStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *lfs) test func(t *testing.T, ctx context.Context, db *lfsStore)
}{ }{
{"CreateObject", lfsCreateObject}, {"CreateObject", lfsCreateObject},
{"GetObjectByOID", lfsGetObjectByOID}, {"GetObjectByOID", lfsGetObjectByOID},
@ -48,7 +48,7 @@ func TestLFS(t *testing.T) {
} }
} }
func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) { func lfsCreateObject(t *testing.T, ctx context.Context, db *lfsStore) {
// Create first LFS object // Create first LFS object
repoID := int64(1) repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -65,7 +65,7 @@ func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) {
assert.Error(t, err) assert.Error(t, err)
} }
func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) { func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfsStore) {
// Create a LFS object // Create a LFS object
repoID := int64(1) repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -82,7 +82,7 @@ func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) {
assert.Equal(t, expErr, err) assert.Equal(t, expErr, err)
} }
func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfs) { func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfsStore) {
// Create two LFS objects // Create two LFS objects
repoID := int64(1) repoID := int64(1)
oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")

View File

@ -180,9 +180,9 @@ func (s *LoginSource) GitHub() *github.Config {
return s.Provider.Config().(*github.Config) return s.Provider.Config().(*github.Config)
} }
var _ LoginSourcesStore = (*loginSources)(nil) var _ LoginSourcesStore = (*loginSourcesStore)(nil)
type loginSources struct { type loginSourcesStore struct {
*gorm.DB *gorm.DB
files loginSourceFilesStore files loginSourceFilesStore
} }
@ -208,8 +208,8 @@ func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists: %v", err.args) return fmt.Sprintf("login source already exists: %v", err.args)
} }
func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) { func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
err := db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil { if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}} return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
} else if err != gorm.ErrRecordNotFound { } else if err != gorm.ErrRecordNotFound {
@ -226,13 +226,13 @@ func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOption
if err != nil { if err != nil {
return nil, err return nil, err
} }
return source, db.WithContext(ctx).Create(source).Error return source, s.WithContext(ctx).Create(source).Error
} }
func (db *loginSources) Count(ctx context.Context) int64 { func (s *loginSourcesStore) Count(ctx context.Context) int64 {
var count int64 var count int64
db.WithContext(ctx).Model(new(LoginSource)).Count(&count) s.WithContext(ctx).Model(new(LoginSource)).Count(&count)
return count + int64(db.files.Len()) return count + int64(s.files.Len())
} }
type ErrLoginSourceInUse struct { type ErrLoginSourceInUse struct {
@ -248,24 +248,24 @@ func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users: %v", err.args) return fmt.Sprintf("login source is still used by some users: %v", err.args)
} }
func (db *loginSources) DeleteByID(ctx context.Context, id int64) error { func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
var count int64 var count int64
err := db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
if err != nil { if err != nil {
return err return err
} else if count > 0 { } else if count > 0 {
return ErrLoginSourceInUse{args: errutil.Args{"id": id}} return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
} }
return db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
} }
func (db *loginSources) GetByID(ctx context.Context, id int64) (*LoginSource, error) { func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
source := new(LoginSource) source := new(LoginSource)
err := db.WithContext(ctx).Where("id = ?", id).First(source).Error err := s.WithContext(ctx).Where("id = ?", id).First(source).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return db.files.GetByID(id) return s.files.GetByID(id)
} }
return nil, err return nil, err
} }
@ -277,9 +277,9 @@ type ListLoginSourceOptions struct {
OnlyActivated bool OnlyActivated bool
} }
func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) { func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
var sources []*LoginSource var sources []*LoginSource
query := db.WithContext(ctx).Order("id ASC") query := s.WithContext(ctx).Order("id ASC")
if opts.OnlyActivated { if opts.OnlyActivated {
query = query.Where("is_actived = ?", true) query = query.Where("is_actived = ?", true)
} }
@ -288,11 +288,11 @@ func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) (
return nil, err return nil, err
} }
return append(sources, db.files.List(opts)...), nil return append(sources, s.files.List(opts)...), nil
} }
func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) error { func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := db.WithContext(ctx). err := s.WithContext(ctx).
Model(new(LoginSource)). Model(new(LoginSource)).
Where("id != ?", dflt.ID). Where("id != ?", dflt.ID).
Updates(map[string]any{"is_default": false}). Updates(map[string]any{"is_default": false}).
@ -301,7 +301,7 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource)
return err return err
} }
for _, source := range db.files.List(ListLoginSourceOptions{}) { for _, source := range s.files.List(ListLoginSourceOptions{}) {
if source.File != nil && source.ID != dflt.ID { if source.File != nil && source.ID != dflt.ID {
source.File.SetGeneral("is_default", "false") source.File.SetGeneral("is_default", "false")
if err = source.File.Save(); err != nil { if err = source.File.Save(); err != nil {
@ -310,13 +310,13 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource)
} }
} }
db.files.Update(dflt) s.files.Update(dflt)
return nil return nil
} }
func (db *loginSources) Save(ctx context.Context, source *LoginSource) error { func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
if source.File == nil { if source.File == nil {
return db.WithContext(ctx).Save(source).Error return s.WithContext(ctx).Save(source).Error
} }
source.File.SetGeneral("name", source.Name) source.File.SetGeneral("name", source.Name)

View File

@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &loginSources{ db := &loginSourcesStore{
DB: newTestDB(t, "loginSources"), DB: newTestDB(t, "loginSourcesStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *loginSources) test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
}{ }{
{"Create", loginSourcesCreate}, {"Create", loginSourcesCreate},
{"Count", loginSourcesCount}, {"Count", loginSourcesCount},
@ -192,7 +192,7 @@ func TestLoginSources(t *testing.T) {
} }
} }
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) { func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
// Create first login source with name "GitHub" // Create first login source with name "GitHub"
source, err := db.Create(ctx, source, err := db.Create(ctx,
CreateLoginSourceOptions{ CreateLoginSourceOptions{
@ -219,7 +219,7 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) { func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
// Create two login sources, one in database and one as source file. // Create two login sources, one in database and one as source file.
_, err := db.Create(ctx, _, err := db.Create(ctx,
CreateLoginSourceOptions{ CreateLoginSourceOptions{
@ -241,7 +241,7 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
assert.Equal(t, int64(3), db.Count(ctx)) assert.Equal(t, int64(3), db.Count(ctx))
} }
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) { func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
t.Run("delete but in used", func(t *testing.T) { t.Run("delete but in used", func(t *testing.T) {
source, err := db.Create(ctx, source, err := db.Create(ctx,
CreateLoginSourceOptions{ CreateLoginSourceOptions{
@ -257,7 +257,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources)
require.NoError(t, err) require.NoError(t, err)
// Create a user that uses this login source // Create a user that uses this login source
_, err = (&users{DB: db.DB}).Create(ctx, "alice", "", _, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
CreateUserOptions{ CreateUserOptions{
LoginSource: source.ID, LoginSource: source.ID,
}, },
@ -308,7 +308,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources)
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) { func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
mock := NewMockLoginSourceFilesStore() mock := NewMockLoginSourceFilesStore()
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) { mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
if id != 101 { if id != 101 {
@ -344,7 +344,7 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
require.NoError(t, err) require.NoError(t, err)
} }
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) { func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
mock := NewMockLoginSourceFilesStore() mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
if opts.OnlyActivated { if opts.OnlyActivated {
@ -393,7 +393,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
assert.Equal(t, 2, len(sources), "number of sources") assert.Equal(t, 2, len(sources), "number of sources")
} }
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) { func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
mock := NewMockLoginSourceFilesStore() mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
mockFile := NewMockLoginSourceFileStore() mockFile := NewMockLoginSourceFileStore()
@ -448,7 +448,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
assert.False(t, source2.IsDefault) assert.False(t, source2.IsDefault)
} }
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) { func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
t.Run("save to database", func(t *testing.T) { t.Run("save to database", func(t *testing.T) {
// Create a login source with name "GitHub" // Create a login source with name "GitHub"
source, err := db.Create(ctx, source, err := db.Create(ctx,

View File

@ -32,7 +32,7 @@ func setMockLoginSourcesStore(t *testing.T, mock LoginSourcesStore) {
}) })
} }
func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) { func setMockLoginSourceFilesStore(t *testing.T, db *loginSourcesStore, mock loginSourceFilesStore) {
before := db.files before := db.files
db.files = mock db.files = mock
t.Cleanup(func() { t.Cleanup(func() {

View File

@ -32,20 +32,20 @@ type NoticesStore interface {
var Notices NoticesStore var Notices NoticesStore
var _ NoticesStore = (*notices)(nil) var _ NoticesStore = (*noticesStore)(nil)
type notices struct { type noticesStore struct {
*gorm.DB *gorm.DB
} }
// NewNoticesStore returns a persistent interface for system notices with given // NewNoticesStore returns a persistent interface for system notices with given
// database connection. // database connection.
func NewNoticesStore(db *gorm.DB) NoticesStore { func NewNoticesStore(db *gorm.DB) NoticesStore {
return &notices{DB: db} return &noticesStore{DB: db}
} }
func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) error { func (s *noticesStore) Create(ctx context.Context, typ NoticeType, desc string) error {
return db.WithContext(ctx).Create( return s.WithContext(ctx).Create(
&Notice{ &Notice{
Type: typ, Type: typ,
Description: desc, Description: desc,
@ -53,26 +53,26 @@ func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) erro
).Error ).Error
} }
func (db *notices) DeleteByIDs(ctx context.Context, ids ...int64) error { func (s *noticesStore) DeleteByIDs(ctx context.Context, ids ...int64) error {
return db.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error return s.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error
} }
func (db *notices) DeleteAll(ctx context.Context) error { func (s *noticesStore) DeleteAll(ctx context.Context) error {
return db.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error return s.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error
} }
func (db *notices) List(ctx context.Context, page, pageSize int) ([]*Notice, error) { func (s *noticesStore) List(ctx context.Context, page, pageSize int) ([]*Notice, error) {
notices := make([]*Notice, 0, pageSize) notices := make([]*Notice, 0, pageSize)
return notices, db.WithContext(ctx). return notices, s.WithContext(ctx).
Limit(pageSize).Offset((page - 1) * pageSize). Limit(pageSize).Offset((page - 1) * pageSize).
Order("id DESC"). Order("id DESC").
Find(&notices). Find(&notices).
Error Error
} }
func (db *notices) Count(ctx context.Context) int64 { func (s *noticesStore) Count(ctx context.Context) int64 {
var count int64 var count int64
db.WithContext(ctx).Model(&Notice{}).Count(&count) s.WithContext(ctx).Model(&Notice{}).Count(&count)
return count return count
} }

View File

@ -65,13 +65,13 @@ func TestNotices(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &notices{ db := &noticesStore{
DB: newTestDB(t, "notices"), DB: newTestDB(t, "noticesStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *notices) test func(t *testing.T, ctx context.Context, db *noticesStore)
}{ }{
{"Create", noticesCreate}, {"Create", noticesCreate},
{"DeleteByIDs", noticesDeleteByIDs}, {"DeleteByIDs", noticesDeleteByIDs},
@ -92,7 +92,7 @@ func TestNotices(t *testing.T) {
} }
} }
func noticesCreate(t *testing.T, ctx context.Context, db *notices) { func noticesCreate(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test") err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err) require.NoError(t, err)
@ -100,7 +100,7 @@ func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
assert.Equal(t, int64(1), count) assert.Equal(t, int64(1), count)
} }
func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) { func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test") err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err) require.NoError(t, err)
@ -120,7 +120,7 @@ func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
} }
func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) { func noticesDeleteAll(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test") err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err) require.NoError(t, err)
@ -131,7 +131,7 @@ func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
} }
func noticesList(t *testing.T, ctx context.Context, db *notices) { func noticesList(t *testing.T, ctx context.Context, db *noticesStore) {
err := db.Create(ctx, NoticeTypeRepository, "test 1") err := db.Create(ctx, NoticeTypeRepository, "test 1")
require.NoError(t, err) require.NoError(t, err)
err = db.Create(ctx, NoticeTypeRepository, "test 2") err = db.Create(ctx, NoticeTypeRepository, "test 2")
@ -151,7 +151,7 @@ func noticesList(t *testing.T, ctx context.Context, db *notices) {
require.Len(t, got, 2) require.Len(t, got, 2)
} }
func noticesCount(t *testing.T, ctx context.Context, db *notices) { func noticesCount(t *testing.T, ctx context.Context, db *noticesStore) {
count := db.Count(ctx) count := db.Count(ctx)
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)

View File

@ -30,16 +30,16 @@ type OrgsStore interface {
var Orgs OrgsStore var Orgs OrgsStore
var _ OrgsStore = (*orgs)(nil) var _ OrgsStore = (*orgsStore)(nil)
type orgs struct { type orgsStore struct {
*gorm.DB *gorm.DB
} }
// NewOrgsStore returns a persistent interface for orgs with given database // NewOrgsStore returns a persistent interface for orgs with given database
// connection. // connection.
func NewOrgsStore(db *gorm.DB) OrgsStore { func NewOrgsStore(db *gorm.DB) OrgsStore {
return &orgs{DB: db} return &orgsStore{DB: db}
} }
type ListOrgsOptions struct { type ListOrgsOptions struct {
@ -49,7 +49,7 @@ type ListOrgsOptions struct {
IncludePrivateMembers bool IncludePrivateMembers bool
} }
func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) { func (s *orgsStore) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) {
if opts.MemberID <= 0 { if opts.MemberID <= 0 {
return nil, errors.New("MemberID must be greater than 0") return nil, errors.New("MemberID must be greater than 0")
} }
@ -64,7 +64,7 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization
[AND org_user.is_public = @includePrivateMembers] [AND org_user.is_public = @includePrivateMembers]
ORDER BY org.id ASC ORDER BY org.id ASC
*/ */
tx := db.WithContext(ctx). tx := s.WithContext(ctx).
Joins(dbutil.Quote("JOIN org_user ON org_user.org_id = %s.id", "user")). Joins(dbutil.Quote("JOIN org_user ON org_user.org_id = %s.id", "user")).
Where("org_user.uid = ?", opts.MemberID). Where("org_user.uid = ?", opts.MemberID).
Order(dbutil.Quote("%s.id ASC", "user")) Order(dbutil.Quote("%s.id ASC", "user"))
@ -76,13 +76,13 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization
return orgs, tx.Find(&orgs).Error return orgs, tx.Find(&orgs).Error
} }
func (db *orgs) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) { func (s *orgsStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) {
return searchUserByName(ctx, db.DB, UserTypeOrganization, keyword, page, pageSize, orderBy) return searchUserByName(ctx, s.DB, UserTypeOrganization, keyword, page, pageSize, orderBy)
} }
func (db *orgs) CountByUser(ctx context.Context, userID int64) (int64, error) { func (s *orgsStore) CountByUser(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
return count, db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error return count, s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
} }
type Organization = User type Organization = User

View File

@ -21,13 +21,13 @@ func TestOrgs(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &orgs{ db := &orgsStore{
DB: newTestDB(t, "orgs"), DB: newTestDB(t, "orgsStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *orgs) test func(t *testing.T, ctx context.Context, db *orgsStore)
}{ }{
{"List", orgsList}, {"List", orgsList},
{"SearchByName", orgsSearchByName}, {"SearchByName", orgsSearchByName},
@ -46,7 +46,7 @@ func TestOrgs(t *testing.T) {
} }
} }
func orgsList(t *testing.T, ctx context.Context, db *orgs) { func orgsList(t *testing.T, ctx context.Context, db *orgsStore) {
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -116,7 +116,7 @@ func orgsList(t *testing.T, ctx context.Context, db *orgs) {
} }
} }
func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) { func orgsSearchByName(t *testing.T, ctx context.Context, db *orgsStore) {
// TODO: Use Orgs.Create to replace SQL hack when the method is available. // TODO: Use Orgs.Create to replace SQL hack when the method is available.
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"}) org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"})
@ -161,7 +161,7 @@ func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) {
}) })
} }
func orgsCountByUser(t *testing.T, ctx context.Context, db *orgs) { func orgsCountByUser(t *testing.T, ctx context.Context, db *orgsStore) {
// TODO: Use Orgs.Join to replace SQL hack when the method is available. // TODO: Use Orgs.Join to replace SQL hack when the method is available.
err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error
require.NoError(t, err) require.NoError(t, err)

View File

@ -74,16 +74,16 @@ func ParseAccessMode(permission string) AccessMode {
} }
} }
var _ PermsStore = (*perms)(nil) var _ PermsStore = (*permsStore)(nil)
type perms struct { type permsStore struct {
*gorm.DB *gorm.DB
} }
// NewPermsStore returns a persistent interface for permissions with given // NewPermsStore returns a persistent interface for permissions with given
// database connection. // database connection.
func NewPermsStore(db *gorm.DB) PermsStore { func NewPermsStore(db *gorm.DB) PermsStore {
return &perms{DB: db} return &permsStore{DB: db}
} }
type AccessModeOptions struct { type AccessModeOptions struct {
@ -91,7 +91,7 @@ type AccessModeOptions struct {
Private bool // Whether the repository is private. Private bool // Whether the repository is private.
} }
func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) { func (s *permsStore) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) {
if repoID <= 0 { if repoID <= 0 {
return AccessModeNone return AccessModeNone
} }
@ -111,7 +111,7 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce
} }
access := new(Access) access := new(Access)
err := db.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error err := s.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error
if err != nil { if err != nil {
if err != gorm.ErrRecordNotFound { if err != gorm.ErrRecordNotFound {
log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repoID, err) log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repoID, err)
@ -121,11 +121,11 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce
return access.Mode return access.Mode
} }
func (db *perms) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool { func (s *permsStore) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
return desired <= db.AccessMode(ctx, userID, repoID, opts) return desired <= s.AccessMode(ctx, userID, repoID, opts)
} }
func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error { func (s *permsStore) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error {
records := make([]*Access, 0, len(accessMap)) records := make([]*Access, 0, len(accessMap))
for userID, mode := range accessMap { for userID, mode := range accessMap {
records = append(records, &Access{ records = append(records, &Access{
@ -135,7 +135,7 @@ func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[i
}) })
} }
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Where("repo_id = ?", repoID).Delete(new(Access)).Error err := tx.Where("repo_id = ?", repoID).Delete(new(Access)).Error
if err != nil { if err != nil {
return err return err

View File

@ -19,13 +19,13 @@ func TestPerms(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &perms{ db := &permsStore{
DB: newTestDB(t, "perms"), DB: newTestDB(t, "permsStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *perms) test func(t *testing.T, ctx context.Context, db *permsStore)
}{ }{
{"AccessMode", permsAccessMode}, {"AccessMode", permsAccessMode},
{"Authorize", permsAuthorize}, {"Authorize", permsAuthorize},
@ -44,7 +44,7 @@ func TestPerms(t *testing.T) {
} }
} }
func permsAccessMode(t *testing.T, ctx context.Context, db *perms) { func permsAccessMode(t *testing.T, ctx context.Context, db *permsStore) {
// Set up permissions // Set up permissions
err := db.SetRepoPerms(ctx, 1, err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{ map[int64]AccessMode{
@ -155,7 +155,7 @@ func permsAccessMode(t *testing.T, ctx context.Context, db *perms) {
} }
} }
func permsAuthorize(t *testing.T, ctx context.Context, db *perms) { func permsAuthorize(t *testing.T, ctx context.Context, db *permsStore) {
// Set up permissions // Set up permissions
err := db.SetRepoPerms(ctx, 1, err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{ map[int64]AccessMode{
@ -241,7 +241,7 @@ func permsAuthorize(t *testing.T, ctx context.Context, db *perms) {
} }
} }
func permsSetRepoPerms(t *testing.T, ctx context.Context, db *perms) { func permsSetRepoPerms(t *testing.T, ctx context.Context, db *permsStore) {
for _, update := range []struct { for _, update := range []struct {
repoID int64 repoID int64
accessMap map[int64]AccessMode accessMap map[int64]AccessMode

View File

@ -24,23 +24,23 @@ type PublicKeysStore interface {
var PublicKeys PublicKeysStore var PublicKeys PublicKeysStore
var _ PublicKeysStore = (*publicKeys)(nil) var _ PublicKeysStore = (*publicKeysStore)(nil)
type publicKeys struct { type publicKeysStore struct {
*gorm.DB *gorm.DB
} }
// NewPublicKeysStore returns a persistent interface for public keys with given // NewPublicKeysStore returns a persistent interface for public keys with given
// database connection. // database connection.
func NewPublicKeysStore(db *gorm.DB) PublicKeysStore { func NewPublicKeysStore(db *gorm.DB) PublicKeysStore {
return &publicKeys{DB: db} return &publicKeysStore{DB: db}
} }
func authorizedKeysPath() string { func authorizedKeysPath() string {
return filepath.Join(conf.SSH.RootPath, "authorized_keys") return filepath.Join(conf.SSH.RootPath, "authorized_keys")
} }
func (db *publicKeys) RewriteAuthorizedKeys() error { func (s *publicKeysStore) RewriteAuthorizedKeys() error {
sshOpLocker.Lock() sshOpLocker.Lock()
defer sshOpLocker.Unlock() defer sshOpLocker.Unlock()
@ -61,7 +61,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error {
// NOTE: More recently updated keys are more likely to be used more frequently, // NOTE: More recently updated keys are more likely to be used more frequently,
// putting them in the earlier lines could speed up the key lookup by SSHD. // putting them in the earlier lines could speed up the key lookup by SSHD.
rows, err := db.Model(&PublicKey{}).Order("updated_unix DESC").Rows() rows, err := s.Model(&PublicKey{}).Order("updated_unix DESC").Rows()
if err != nil { if err != nil {
return errors.Wrap(err, "iterate public keys") return errors.Wrap(err, "iterate public keys")
} }
@ -69,7 +69,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error {
for rows.Next() { for rows.Next() {
var key PublicKey var key PublicKey
err = db.ScanRows(rows, &key) err = s.ScanRows(rows, &key)
if err != nil { if err != nil {
return errors.Wrap(err, "scan rows") return errors.Wrap(err, "scan rows")
} }

View File

@ -24,13 +24,13 @@ func TestPublicKeys(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &publicKeys{ db := &publicKeysStore{
DB: newTestDB(t, "publicKeys"), DB: newTestDB(t, "publicKeysStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *publicKeys) test func(t *testing.T, ctx context.Context, db *publicKeysStore)
}{ }{
{"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys}, {"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys},
} { } {
@ -47,7 +47,7 @@ func TestPublicKeys(t *testing.T) {
} }
} }
func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeys) { func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeysStore) {
// TODO: Use PublicKeys.Add to replace SQL hack when the method is available. // TODO: Use PublicKeys.Add to replace SQL hack when the method is available.
publicKey := &PublicKey{ publicKey := &PublicKey{
OwnerID: 1, OwnerID: 1,

View File

@ -119,16 +119,16 @@ func (r *Repository) APIFormat(owner *User, opts ...RepositoryAPIFormatOptions)
} }
} }
var _ ReposStore = (*repos)(nil) var _ ReposStore = (*reposStore)(nil)
type repos struct { type reposStore struct {
*gorm.DB *gorm.DB
} }
// NewReposStore returns a persistent interface for repositories with given // NewReposStore returns a persistent interface for repositories with given
// database connection. // database connection.
func NewReposStore(db *gorm.DB) ReposStore { func NewReposStore(db *gorm.DB) ReposStore {
return &repos{DB: db} return &reposStore{DB: db}
} }
type ErrRepoAlreadyExist struct { type ErrRepoAlreadyExist struct {
@ -157,13 +157,13 @@ type CreateRepoOptions struct {
ForkID int64 ForkID int64
} }
func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) { func (s *reposStore) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) {
err := isRepoNameAllowed(opts.Name) err := isRepoNameAllowed(opts.Name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = db.GetByName(ctx, ownerID, opts.Name) _, err = s.GetByName(ctx, ownerID, opts.Name)
if err == nil { if err == nil {
return nil, ErrRepoAlreadyExist{ return nil, ErrRepoAlreadyExist{
args: errutil.Args{ args: errutil.Args{
@ -189,7 +189,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio
IsFork: opts.Fork, IsFork: opts.Fork,
ForkID: opts.ForkID, ForkID: opts.ForkID,
} }
return repo, db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return repo, s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err = tx.Create(repo).Error err = tx.Create(repo).Error
if err != nil { if err != nil {
return errors.Wrap(err, "create") return errors.Wrap(err, "create")
@ -203,7 +203,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio
}) })
} }
func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) { func (s *reposStore) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -214,7 +214,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64,
LIMIT @limit LIMIT @limit
*/ */
var repos []*Repository var repos []*Repository
return repos, db.WithContext(ctx). return repos, s.WithContext(ctx).
Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID). Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID).
Where("access.mode >= ?", AccessModeRead). Where("access.mode >= ?", AccessModeRead).
Order(orderBy). Order(orderBy).
@ -223,7 +223,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64,
Error Error
} }
func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) { func (s *reposStore) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -238,7 +238,7 @@ func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collabor
*Repository *Repository
Mode AccessMode Mode AccessMode
} }
err := db.WithContext(ctx). err := s.WithContext(ctx).
Select("repository.*", "access.mode"). Select("repository.*", "access.mode").
Table("repository"). Table("repository").
Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID). Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID).
@ -275,9 +275,9 @@ func (ErrRepoNotExist) NotFound() bool {
return true return true
} }
func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) { func (s *reposStore) GetByID(ctx context.Context, id int64) (*Repository, error) {
repo := new(Repository) repo := new(Repository)
err := db.WithContext(ctx).Where("id = ?", id).First(repo).Error err := s.WithContext(ctx).Where("id = ?", id).First(repo).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, ErrRepoNotExist{errutil.Args{"repoID": id}} return nil, ErrRepoNotExist{errutil.Args{"repoID": id}}
@ -287,9 +287,9 @@ func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) {
return repo, nil return repo, nil
} }
func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) { func (s *reposStore) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) {
repo := new(Repository) repo := new(Repository)
err := db.WithContext(ctx). err := s.WithContext(ctx).
Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)). Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).
First(repo). First(repo).
Error Error
@ -307,7 +307,7 @@ func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Re
return repo, nil return repo, nil
} }
func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error { func (s *reposStore) recountStars(tx *gorm.DB, userID, repoID int64) error {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -350,40 +350,40 @@ func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error {
return nil return nil
} }
func (db *repos) Star(ctx context.Context, userID, repoID int64) error { func (s *reposStore) Star(ctx context.Context, userID, repoID int64) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
s := &Star{ star := &Star{
UserID: userID, UserID: userID,
RepoID: repoID, RepoID: repoID,
} }
result := tx.FirstOrCreate(s, s) result := tx.FirstOrCreate(star, star)
if result.Error != nil { if result.Error != nil {
return errors.Wrap(result.Error, "upsert") return errors.Wrap(result.Error, "upsert")
} else if result.RowsAffected <= 0 { } else if result.RowsAffected <= 0 {
return nil // Relation already exists return nil // Relation already exists
} }
return db.recountStars(tx, userID, repoID) return s.recountStars(tx, userID, repoID)
}) })
} }
func (db *repos) Touch(ctx context.Context, id int64) error { func (s *reposStore) Touch(ctx context.Context, id int64) error {
return db.WithContext(ctx). return s.WithContext(ctx).
Model(new(Repository)). Model(new(Repository)).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"is_bare": false, "is_bare": false,
"updated_unix": db.NowFunc().Unix(), "updated_unix": s.NowFunc().Unix(),
}). }).
Error Error
} }
func (db *repos) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) { func (s *reposStore) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) {
var watches []*Watch var watches []*Watch
return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error return watches, s.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error
} }
func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error { func (s *reposStore) recountWatches(tx *gorm.DB, repoID int64) error {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -402,8 +402,8 @@ func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error {
Error Error
} }
func (db *repos) Watch(ctx context.Context, userID, repoID int64) error { func (s *reposStore) Watch(ctx context.Context, userID, repoID int64) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
w := &Watch{ w := &Watch{
UserID: userID, UserID: userID,
RepoID: repoID, RepoID: repoID,
@ -415,12 +415,12 @@ func (db *repos) Watch(ctx context.Context, userID, repoID int64) error {
return nil // Relation already exists return nil // Relation already exists
} }
return db.recountWatches(tx, repoID) return s.recountWatches(tx, repoID)
}) })
} }
func (db *repos) HasForkedBy(ctx context.Context, repoID, userID int64) bool { func (s *reposStore) HasForkedBy(ctx context.Context, repoID, userID int64) bool {
var count int64 var count int64
db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count) s.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count)
return count > 0 return count > 0
} }

View File

@ -85,13 +85,13 @@ func TestRepos(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &repos{ db := &reposStore{
DB: newTestDB(t, "repos"), DB: newTestDB(t, "repos"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *repos) test func(t *testing.T, ctx context.Context, db *reposStore)
}{ }{
{"Create", reposCreate}, {"Create", reposCreate},
{"GetByCollaboratorID", reposGetByCollaboratorID}, {"GetByCollaboratorID", reposGetByCollaboratorID},
@ -117,7 +117,7 @@ func TestRepos(t *testing.T) {
} }
} }
func reposCreate(t *testing.T, ctx context.Context, db *repos) { func reposCreate(t *testing.T, ctx context.Context, db *reposStore) {
t.Run("name not allowed", func(t *testing.T) { t.Run("name not allowed", func(t *testing.T) {
_, err := db.Create(ctx, _, err := db.Create(ctx,
1, 1,
@ -159,7 +159,7 @@ func reposCreate(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default. assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default.
} }
func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) { func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -185,7 +185,7 @@ func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) {
}) })
} }
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *repos) { func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -213,7 +213,7 @@ func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, d
assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID]) assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID])
} }
func reposGetByID(t *testing.T, ctx context.Context, db *repos) { func reposGetByID(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
@ -226,7 +226,7 @@ func reposGetByID(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func reposGetByName(t *testing.T, ctx context.Context, db *repos) { func reposGetByName(t *testing.T, ctx context.Context, db *reposStore) {
repo, err := db.Create(ctx, 1, repo, err := db.Create(ctx, 1,
CreateRepoOptions{ CreateRepoOptions{
Name: "repo1", Name: "repo1",
@ -242,7 +242,7 @@ func reposGetByName(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func reposStar(t *testing.T, ctx context.Context, db *repos) { func reposStar(t *testing.T, ctx context.Context, db *reposStore) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
@ -261,7 +261,7 @@ func reposStar(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, 1, alice.NumStars) assert.Equal(t, 1, alice.NumStars)
} }
func reposTouch(t *testing.T, ctx context.Context, db *repos) { func reposTouch(t *testing.T, ctx context.Context, db *reposStore) {
repo, err := db.Create(ctx, 1, repo, err := db.Create(ctx, 1,
CreateRepoOptions{ CreateRepoOptions{
Name: "repo1", Name: "repo1",
@ -287,7 +287,7 @@ func reposTouch(t *testing.T, ctx context.Context, db *repos) {
assert.False(t, got.IsBare) assert.False(t, got.IsBare)
} }
func reposListWatches(t *testing.T, ctx context.Context, db *repos) { func reposListWatches(t *testing.T, ctx context.Context, db *reposStore) {
err := db.Watch(ctx, 1, 1) err := db.Watch(ctx, 1, 1)
require.NoError(t, err) require.NoError(t, err)
err = db.Watch(ctx, 2, 1) err = db.Watch(ctx, 2, 1)
@ -308,7 +308,7 @@ func reposListWatches(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func reposWatch(t *testing.T, ctx context.Context, db *repos) { func reposWatch(t *testing.T, ctx context.Context, db *reposStore) {
reposStore := NewReposStore(db.DB) reposStore := NewReposStore(db.DB)
repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
@ -325,7 +325,7 @@ func reposWatch(t *testing.T, ctx context.Context, db *repos) {
assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default. assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default.
} }
func reposHasForkedBy(t *testing.T, ctx context.Context, db *repos) { func reposHasForkedBy(t *testing.T, ctx context.Context, db *reposStore) {
has := db.HasForkedBy(ctx, 1, 2) has := db.HasForkedBy(ctx, 1, 2)
assert.False(t, has) assert.False(t, has)

View File

@ -50,13 +50,13 @@ func (t *TwoFactor) AfterFind(_ *gorm.DB) error {
return nil return nil
} }
var _ TwoFactorsStore = (*twoFactors)(nil) var _ TwoFactorsStore = (*twoFactorsStore)(nil)
type twoFactors struct { type twoFactorsStore struct {
*gorm.DB *gorm.DB
} }
func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret string) error { func (s *twoFactorsStore) Create(ctx context.Context, userID int64, key, secret string) error {
encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret)) encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret))
if err != nil { if err != nil {
return errors.Wrap(err, "encrypt secret") return errors.Wrap(err, "encrypt secret")
@ -71,7 +71,7 @@ func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret stri
return errors.Wrap(err, "generate recovery codes") return errors.Wrap(err, "generate recovery codes")
} }
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Create(tf).Error err := tx.Create(tf).Error
if err != nil { if err != nil {
return err return err
@ -100,9 +100,9 @@ func (ErrTwoFactorNotFound) NotFound() bool {
return true return true
} }
func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) { func (s *twoFactorsStore) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) {
tf := new(TwoFactor) tf := new(TwoFactor)
err := db.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error err := s.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}} return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}}
@ -112,9 +112,9 @@ func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor
return tf, nil return tf, nil
} }
func (db *twoFactors) IsEnabled(ctx context.Context, userID int64) bool { func (s *twoFactorsStore) IsEnabled(ctx context.Context, userID int64) bool {
var count int64 var count int64
err := db.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error err := s.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
if err != nil { if err != nil {
log.Error("Failed to count two factors [user_id: %d]: %v", userID, err) log.Error("Failed to count two factors [user_id: %d]: %v", userID, err)
} }

View File

@ -67,13 +67,13 @@ func TestTwoFactors(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &twoFactors{ db := &twoFactorsStore{
DB: newTestDB(t, "twoFactors"), DB: newTestDB(t, "twoFactorsStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *twoFactors) test func(t *testing.T, ctx context.Context, db *twoFactorsStore)
}{ }{
{"Create", twoFactorsCreate}, {"Create", twoFactorsCreate},
{"GetByUserID", twoFactorsGetByUserID}, {"GetByUserID", twoFactorsGetByUserID},
@ -92,7 +92,7 @@ func TestTwoFactors(t *testing.T) {
} }
} }
func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) { func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactorsStore) {
// Create a 2FA token // Create a 2FA token
err := db.Create(ctx, 1, "secure-key", "secure-secret") err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err) require.NoError(t, err)
@ -109,7 +109,7 @@ func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) {
assert.Equal(t, int64(10), count) assert.Equal(t, int64(10), count)
} }
func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) { func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactorsStore) {
// Create a 2FA token for user 1 // Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret") err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err) require.NoError(t, err)
@ -124,7 +124,7 @@ func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactors) { func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactorsStore) {
// Create a 2FA token for user 1 // Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret") err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err) require.NoError(t, err)

View File

@ -146,16 +146,16 @@ type UsersStore interface {
var Users UsersStore var Users UsersStore
var _ UsersStore = (*users)(nil) var _ UsersStore = (*usersStore)(nil)
type users struct { type usersStore struct {
*gorm.DB *gorm.DB
} }
// NewUsersStore returns a persistent interface for users with given database // NewUsersStore returns a persistent interface for users with given database
// connection. // connection.
func NewUsersStore(db *gorm.DB) UsersStore { func NewUsersStore(db *gorm.DB) UsersStore {
return &users{DB: db} return &usersStore{DB: db}
} }
type ErrLoginSourceMismatch struct { type ErrLoginSourceMismatch struct {
@ -173,10 +173,10 @@ func (err ErrLoginSourceMismatch) Error() string {
return fmt.Sprintf("login source mismatch: %v", err.args) return fmt.Sprintf("login source mismatch: %v", err.args)
} }
func (db *users) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) { func (s *usersStore) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) {
login = strings.ToLower(login) login = strings.ToLower(login)
query := db.WithContext(ctx) query := s.WithContext(ctx)
if strings.Contains(login, "@") { if strings.Contains(login, "@") {
query = query.Where("email = ?", login) query = query.Where("email = ?", login)
} else { } else {
@ -244,7 +244,7 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login
return nil, fmt.Errorf("invalid pattern for attribute 'username' [%s]: must be valid alpha or numeric or dash(-_) or dot characters", extAccount.Name) return nil, fmt.Errorf("invalid pattern for attribute 'username' [%s]: must be valid alpha or numeric or dash(-_) or dot characters", extAccount.Name)
} }
return db.Create(ctx, extAccount.Name, extAccount.Email, return s.Create(ctx, extAccount.Name, extAccount.Email,
CreateUserOptions{ CreateUserOptions{
FullName: extAccount.FullName, FullName: extAccount.FullName,
LoginSource: authSourceID, LoginSource: authSourceID,
@ -257,13 +257,13 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login
) )
} }
func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername string) error { func (s *usersStore) ChangeUsername(ctx context.Context, userID int64, newUsername string) error {
err := isUsernameAllowed(newUsername) err := isUsernameAllowed(newUsername)
if err != nil { if err != nil {
return err return err
} }
if db.IsUsernameUsed(ctx, newUsername, userID) { if s.IsUsernameUsed(ctx, newUsername, userID) {
return ErrUserAlreadyExist{ return ErrUserAlreadyExist{
args: errutil.Args{ args: errutil.Args{
"name": newUsername, "name": newUsername,
@ -271,12 +271,12 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
} }
} }
user, err := db.GetByID(ctx, userID) user, err := s.GetByID(ctx, userID)
if err != nil { if err != nil {
return errors.Wrap(err, "get user") return errors.Wrap(err, "get user")
} }
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Model(&User{}). err := tx.Model(&User{}).
Where("id = ?", user.ID). Where("id = ?", user.ID).
Updates(map[string]any{ Updates(map[string]any{
@ -338,9 +338,9 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
}) })
} }
func (db *users) Count(ctx context.Context) int64 { func (s *usersStore) Count(ctx context.Context) int64 {
var count int64 var count int64
db.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count) s.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count)
return count return count
} }
@ -393,13 +393,13 @@ func (err ErrEmailAlreadyUsed) Error() string {
return fmt.Sprintf("email has been used: %v", err.args) return fmt.Sprintf("email has been used: %v", err.args)
} }
func (db *users) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) { func (s *usersStore) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) {
err := isUsernameAllowed(username) err := isUsernameAllowed(username)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if db.IsUsernameUsed(ctx, username, 0) { if s.IsUsernameUsed(ctx, username, 0) {
return nil, ErrUserAlreadyExist{ return nil, ErrUserAlreadyExist{
args: errutil.Args{ args: errutil.Args{
"name": username, "name": username,
@ -408,7 +408,7 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
} }
email = strings.ToLower(strings.TrimSpace(email)) email = strings.ToLower(strings.TrimSpace(email))
_, err = db.GetByEmail(ctx, email) _, err = s.GetByEmail(ctx, email)
if err == nil { if err == nil {
return nil, ErrEmailAlreadyUsed{ return nil, ErrEmailAlreadyUsed{
args: errutil.Args{ args: errutil.Args{
@ -446,17 +446,17 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
} }
user.Password = userutil.EncodePassword(user.Password, user.Salt) user.Password = userutil.EncodePassword(user.Password, user.Salt)
return user, db.WithContext(ctx).Create(user).Error return user, s.WithContext(ctx).Create(user).Error
} }
func (db *users) DeleteCustomAvatar(ctx context.Context, userID int64) error { func (s *usersStore) DeleteCustomAvatar(ctx context.Context, userID int64) error {
_ = os.Remove(userutil.CustomAvatarPath(userID)) _ = os.Remove(userutil.CustomAvatarPath(userID))
return db.WithContext(ctx). return s.WithContext(ctx).
Model(&User{}). Model(&User{}).
Where("id = ?", userID). Where("id = ?", userID).
Updates(map[string]any{ Updates(map[string]any{
"use_custom_avatar": false, "use_custom_avatar": false,
"updated_unix": db.NowFunc().Unix(), "updated_unix": s.NowFunc().Unix(),
}). }).
Error Error
} }
@ -491,8 +491,8 @@ func (err ErrUserHasOrgs) Error() string {
return fmt.Sprintf("user still has organization membership: %v", err.args) return fmt.Sprintf("user still has organization membership: %v", err.args)
} }
func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error { func (s *usersStore) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error {
user, err := db.GetByID(ctx, userID) user, err := s.GetByID(ctx, userID)
if err != nil { if err != nil {
if IsErrUserNotExist(err) { if IsErrUserNotExist(err) {
return nil return nil
@ -503,14 +503,14 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
// Double check the user is not a direct owner of any repository and not a // Double check the user is not a direct owner of any repository and not a
// member of any organization. // member of any organization.
var count int64 var count int64
err = db.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error err = s.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error
if err != nil { if err != nil {
return errors.Wrap(err, "count repositories") return errors.Wrap(err, "count repositories")
} else if count > 0 { } else if count > 0 {
return ErrUserOwnRepos{args: errutil.Args{"userID": userID}} return ErrUserOwnRepos{args: errutil.Args{"userID": userID}}
} }
err = db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error err = s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
if err != nil { if err != nil {
return errors.Wrap(err, "count organization membership") return errors.Wrap(err, "count organization membership")
} else if count > 0 { } else if count > 0 {
@ -518,7 +518,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
} }
needsRewriteAuthorizedKeys := false needsRewriteAuthorizedKeys := false
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err = s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -645,7 +645,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
_ = os.Remove(userutil.CustomAvatarPath(userID)) _ = os.Remove(userutil.CustomAvatarPath(userID))
if needsRewriteAuthorizedKeys { if needsRewriteAuthorizedKeys {
err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys() err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys()
if err != nil { if err != nil {
return errors.Wrap(err, `rewrite "authorized_keys" file`) return errors.Wrap(err, `rewrite "authorized_keys" file`)
} }
@ -655,15 +655,15 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
// NOTE: We do not take context.Context here because this operation in practice // NOTE: We do not take context.Context here because this operation in practice
// could much longer than the general request timeout (e.g. one minute). // could much longer than the general request timeout (e.g. one minute).
func (db *users) DeleteInactivated() error { func (s *usersStore) DeleteInactivated() error {
var userIDs []int64 var userIDs []int64
err := db.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error err := s.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error
if err != nil { if err != nil {
return errors.Wrap(err, "get inactivated user IDs") return errors.Wrap(err, "get inactivated user IDs")
} }
for _, userID := range userIDs { for _, userID := range userIDs {
err = db.DeleteByID(context.Background(), userID, true) err = s.DeleteByID(context.Background(), userID, true)
if err != nil { if err != nil {
// Skip users that may had set to inactivated by admins. // Skip users that may had set to inactivated by admins.
if IsErrUserOwnRepos(err) || IsErrUserHasOrgs(err) { if IsErrUserOwnRepos(err) || IsErrUserHasOrgs(err) {
@ -672,14 +672,14 @@ func (db *users) DeleteInactivated() error {
return errors.Wrapf(err, "delete user with ID %d", userID) return errors.Wrapf(err, "delete user with ID %d", userID)
} }
} }
err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys() err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys()
if err != nil { if err != nil {
return errors.Wrap(err, `rewrite "authorized_keys" file`) return errors.Wrap(err, `rewrite "authorized_keys" file`)
} }
return nil return nil
} }
func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error { func (*usersStore) recountFollows(tx *gorm.DB, userID, followID int64) error {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -722,12 +722,12 @@ func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error {
return nil return nil
} }
func (db *users) Follow(ctx context.Context, userID, followID int64) error { func (s *usersStore) Follow(ctx context.Context, userID, followID int64) error {
if userID == followID { if userID == followID {
return nil return nil
} }
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
f := &Follow{ f := &Follow{
UserID: userID, UserID: userID,
FollowID: followID, FollowID: followID,
@ -739,26 +739,26 @@ func (db *users) Follow(ctx context.Context, userID, followID int64) error {
return nil // Relation already exists return nil // Relation already exists
} }
return db.recountFollows(tx, userID, followID) return s.recountFollows(tx, userID, followID)
}) })
} }
func (db *users) Unfollow(ctx context.Context, userID, followID int64) error { func (s *usersStore) Unfollow(ctx context.Context, userID, followID int64) error {
if userID == followID { if userID == followID {
return nil return nil
} }
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error
if err != nil { if err != nil {
return errors.Wrap(err, "delete") return errors.Wrap(err, "delete")
} }
return db.recountFollows(tx, userID, followID) return s.recountFollows(tx, userID, followID)
}) })
} }
func (db *users) IsFollowing(ctx context.Context, userID, followID int64) bool { func (s *usersStore) IsFollowing(ctx context.Context, userID, followID int64) bool {
return db.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil return s.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil
} }
var _ errutil.NotFound = (*ErrUserNotExist)(nil) var _ errutil.NotFound = (*ErrUserNotExist)(nil)
@ -782,7 +782,7 @@ func (ErrUserNotExist) NotFound() bool {
return true return true
} }
func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) { func (s *usersStore) GetByEmail(ctx context.Context, email string) (*User, error) {
if email == "" { if email == "" {
return nil, ErrUserNotExist{args: errutil.Args{"email": email}} return nil, ErrUserNotExist{args: errutil.Args{"email": email}}
} }
@ -801,10 +801,10 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
) )
*/ */
user := new(User) user := new(User)
err := db.WithContext(ctx). err := s.WithContext(ctx).
Joins(dbutil.Quote("LEFT JOIN email_address ON email_address.uid = %s.id", "user"), true). Joins(dbutil.Quote("LEFT JOIN email_address ON email_address.uid = %s.id", "user"), true).
Where(dbutil.Quote("%s.type = ?", "user"), UserTypeIndividual). Where(dbutil.Quote("%s.type = ?", "user"), UserTypeIndividual).
Where(db. Where(s.
Where(dbutil.Quote("%[1]s.email = ? AND %[1]s.is_active = ?", "user"), email, true). Where(dbutil.Quote("%[1]s.email = ? AND %[1]s.is_active = ?", "user"), email, true).
Or("email_address.email = ? AND email_address.is_activated = ?", email, true), Or("email_address.email = ? AND email_address.is_activated = ?", email, true),
). ).
@ -819,9 +819,9 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
return user, nil return user, nil
} }
func (db *users) GetByID(ctx context.Context, id int64) (*User, error) { func (s *usersStore) GetByID(ctx context.Context, id int64) (*User, error) {
user := new(User) user := new(User)
err := db.WithContext(ctx).Where("id = ?", id).First(user).Error err := s.WithContext(ctx).Where("id = ?", id).First(user).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"userID": id}} return nil, ErrUserNotExist{args: errutil.Args{"userID": id}}
@ -831,9 +831,9 @@ func (db *users) GetByID(ctx context.Context, id int64) (*User, error) {
return user, nil return user, nil
} }
func (db *users) GetByUsername(ctx context.Context, username string) (*User, error) { func (s *usersStore) GetByUsername(ctx context.Context, username string) (*User, error) {
user := new(User) user := new(User)
err := db.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error err := s.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"name": username}} return nil, ErrUserNotExist{args: errutil.Args{"name": username}}
@ -843,9 +843,9 @@ func (db *users) GetByUsername(ctx context.Context, username string) (*User, err
return user, nil return user, nil
} }
func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) { func (s *usersStore) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
user := new(User) user := new(User)
err := db.WithContext(ctx). err := s.WithContext(ctx).
Joins(dbutil.Quote("JOIN public_key ON public_key.owner_id = %s.id", "user")). Joins(dbutil.Quote("JOIN public_key ON public_key.owner_id = %s.id", "user")).
Where("public_key.id = ?", keyID). Where("public_key.id = ?", keyID).
First(user). First(user).
@ -859,29 +859,29 @@ func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
return user, nil return user, nil
} }
func (db *users) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) { func (s *usersStore) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) {
emails := make([]string, 0, len(usernames)) emails := make([]string, 0, len(usernames))
return emails, db.WithContext(ctx). return emails, s.WithContext(ctx).
Model(&User{}). Model(&User{}).
Select("email"). Select("email").
Where("lower_name IN (?) AND is_active = ?", usernames, true). Where("lower_name IN (?) AND is_active = ?", usernames, true).
Find(&emails).Error Find(&emails).Error
} }
func (db *users) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool { func (s *usersStore) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool {
if username == "" { if username == "" {
return false return false
} }
return db.WithContext(ctx). return s.WithContext(ctx).
Select("id"). Select("id").
Where("lower_name = ? AND id != ?", strings.ToLower(username), excludeUserId). Where("lower_name = ? AND id != ?", strings.ToLower(username), excludeUserId).
First(&User{}). First(&User{}).
Error != gorm.ErrRecordNotFound Error != gorm.ErrRecordNotFound
} }
func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error) { func (s *usersStore) List(ctx context.Context, page, pageSize int) ([]*User, error) {
users := make([]*User, 0, pageSize) users := make([]*User, 0, pageSize)
return users, db.WithContext(ctx). return users, s.WithContext(ctx).
Where("type = ?", UserTypeIndividual). Where("type = ?", UserTypeIndividual).
Limit(pageSize).Offset((page - 1) * pageSize). Limit(pageSize).Offset((page - 1) * pageSize).
Order("id ASC"). Order("id ASC").
@ -889,7 +889,7 @@ func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error)
Error Error
} }
func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { func (s *usersStore) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -900,7 +900,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize
LIMIT @limit OFFSET @offset LIMIT @limit OFFSET @offset
*/ */
users := make([]*User, 0, pageSize) users := make([]*User, 0, pageSize)
return users, db.WithContext(ctx). return users, s.WithContext(ctx).
Joins(dbutil.Quote("LEFT JOIN follow ON follow.user_id = %s.id", "user")). Joins(dbutil.Quote("LEFT JOIN follow ON follow.user_id = %s.id", "user")).
Where("follow.follow_id = ?", userID). Where("follow.follow_id = ?", userID).
Limit(pageSize).Offset((page - 1) * pageSize). Limit(pageSize).Offset((page - 1) * pageSize).
@ -909,7 +909,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize
Error Error
} }
func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { func (s *usersStore) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
/* /*
Equivalent SQL for PostgreSQL: Equivalent SQL for PostgreSQL:
@ -920,7 +920,7 @@ func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSiz
LIMIT @limit OFFSET @offset LIMIT @limit OFFSET @offset
*/ */
users := make([]*User, 0, pageSize) users := make([]*User, 0, pageSize)
return users, db.WithContext(ctx). return users, s.WithContext(ctx).
Joins(dbutil.Quote("LEFT JOIN follow ON follow.follow_id = %s.id", "user")). Joins(dbutil.Quote("LEFT JOIN follow ON follow.follow_id = %s.id", "user")).
Where("follow.user_id = ?", userID). Where("follow.user_id = ?", userID).
Limit(pageSize).Offset((page - 1) * pageSize). Limit(pageSize).Offset((page - 1) * pageSize).
@ -948,8 +948,8 @@ func searchUserByName(ctx context.Context, db *gorm.DB, userType UserType, keywo
return users, count, tx.Order(orderBy).Limit(pageSize).Offset((page - 1) * pageSize).Find(&users).Error return users, count, tx.Order(orderBy).Limit(pageSize).Offset((page - 1) * pageSize).Find(&users).Error
} }
func (db *users) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) { func (s *usersStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) {
return searchUserByName(ctx, db.DB, UserTypeIndividual, keyword, page, pageSize, orderBy) return searchUserByName(ctx, s.DB, UserTypeIndividual, keyword, page, pageSize, orderBy)
} }
type UpdateUserOptions struct { type UpdateUserOptions struct {
@ -979,9 +979,9 @@ type UpdateUserOptions struct {
AvatarEmail *string AvatarEmail *string
} }
func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error { func (s *usersStore) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error {
updates := map[string]any{ updates := map[string]any{
"updated_unix": db.NowFunc().Unix(), "updated_unix": s.NowFunc().Unix(),
} }
if opts.LoginSource != nil { if opts.LoginSource != nil {
@ -1012,7 +1012,7 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption
updates["full_name"] = strutil.Truncate(*opts.FullName, 255) updates["full_name"] = strutil.Truncate(*opts.FullName, 255)
} }
if opts.Email != nil { if opts.Email != nil {
_, err := db.GetByEmail(ctx, *opts.Email) _, err := s.GetByEmail(ctx, *opts.Email)
if err == nil { if err == nil {
return ErrEmailAlreadyUsed{args: errutil.Args{"email": *opts.Email}} return ErrEmailAlreadyUsed{args: errutil.Args{"email": *opts.Email}}
} else if !IsErrUserNotExist(err) { } else if !IsErrUserNotExist(err) {
@ -1063,28 +1063,28 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption
updates["avatar_email"] = strutil.Truncate(*opts.AvatarEmail, 255) updates["avatar_email"] = strutil.Truncate(*opts.AvatarEmail, 255)
} }
return db.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error return s.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error
} }
func (db *users) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error { func (s *usersStore) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error {
err := userutil.SaveAvatar(userID, avatar) err := userutil.SaveAvatar(userID, avatar)
if err != nil { if err != nil {
return errors.Wrap(err, "save avatar") return errors.Wrap(err, "save avatar")
} }
return db.WithContext(ctx). return s.WithContext(ctx).
Model(&User{}). Model(&User{}).
Where("id = ?", userID). Where("id = ?", userID).
Updates(map[string]any{ Updates(map[string]any{
"use_custom_avatar": true, "use_custom_avatar": true,
"updated_unix": db.NowFunc().Unix(), "updated_unix": s.NowFunc().Unix(),
}). }).
Error Error
} }
func (db *users) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error { func (s *usersStore) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error {
email = strings.ToLower(strings.TrimSpace(email)) email = strings.ToLower(strings.TrimSpace(email))
_, err := db.GetByEmail(ctx, email) _, err := s.GetByEmail(ctx, email)
if err == nil { if err == nil {
return ErrEmailAlreadyUsed{ return ErrEmailAlreadyUsed{
args: errutil.Args{ args: errutil.Args{
@ -1095,7 +1095,7 @@ func (db *users) AddEmail(ctx context.Context, userID int64, email string, isAct
return errors.Wrap(err, "check user by email") return errors.Wrap(err, "check user by email")
} }
return db.WithContext(ctx).Create( return s.WithContext(ctx).Create(
&EmailAddress{ &EmailAddress{
UserID: userID, UserID: userID,
Email: email, Email: email,
@ -1125,8 +1125,8 @@ func (ErrEmailNotExist) NotFound() bool {
return true return true
} }
func (db *users) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) { func (s *usersStore) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) {
tx := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email) tx := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email)
if needsActivated { if needsActivated {
tx = tx.Where("is_activated = ?", true) tx = tx.Where("is_activated = ?", true)
} }
@ -1146,14 +1146,14 @@ func (db *users) GetEmail(ctx context.Context, userID int64, email string, needs
return emailAddress, nil return emailAddress, nil
} }
func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) { func (s *usersStore) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) {
user, err := db.GetByID(ctx, userID) user, err := s.GetByID(ctx, userID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "get user") return nil, errors.Wrap(err, "get user")
} }
var emails []*EmailAddress var emails []*EmailAddress
err = db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error err = s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error
if err != nil { if err != nil {
return nil, errors.Wrap(err, "list emails") return nil, errors.Wrap(err, "list emails")
} }
@ -1179,9 +1179,9 @@ func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress,
return emails, nil return emails, nil
} }
func (db *users) MarkEmailActivated(ctx context.Context, userID int64, email string) error { func (s *usersStore) MarkEmailActivated(ctx context.Context, userID int64, email string) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := db.WithContext(ctx). err := s.WithContext(ctx).
Model(&EmailAddress{}). Model(&EmailAddress{}).
Where("uid = ? AND email = ?", userID, email). Where("uid = ? AND email = ?", userID, email).
Update("is_activated", true). Update("is_activated", true).
@ -1209,9 +1209,9 @@ func (err ErrEmailNotVerified) Error() string {
return fmt.Sprintf("email has not been verified: %v", err.args) return fmt.Sprintf("email has not been verified: %v", err.args)
} }
func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email string) error { func (s *usersStore) MarkEmailPrimary(ctx context.Context, userID int64, email string) error {
var emailAddress EmailAddress var emailAddress EmailAddress
err := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error err := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return ErrEmailNotExist{args: errutil.Args{"email": email}} return ErrEmailNotExist{args: errutil.Args{"email": email}}
@ -1223,12 +1223,12 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin
return ErrEmailNotVerified{args: errutil.Args{"email": email}} return ErrEmailNotVerified{args: errutil.Args{"email": email}}
} }
user, err := db.GetByID(ctx, userID) user, err := s.GetByID(ctx, userID)
if err != nil { if err != nil {
return errors.Wrap(err, "get user") return errors.Wrap(err, "get user")
} }
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Make sure the former primary email doesn't disappear. // Make sure the former primary email doesn't disappear.
err = tx.FirstOrCreate( err = tx.FirstOrCreate(
&EmailAddress{ &EmailAddress{
@ -1255,8 +1255,8 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin
}) })
} }
func (db *users) DeleteEmail(ctx context.Context, userID int64, email string) error { func (s *usersStore) DeleteEmail(ctx context.Context, userID int64, email string) error {
return db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error return s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error
} }
// UserType indicates the type of the user account. // UserType indicates the type of the user account.

View File

@ -84,13 +84,13 @@ func TestUsers(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
db := &users{ db := &usersStore{
DB: newTestDB(t, "users"), DB: newTestDB(t, "usersStore"),
} }
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, ctx context.Context, db *users) test func(t *testing.T, ctx context.Context, db *usersStore)
}{ }{
{"Authenticate", usersAuthenticate}, {"Authenticate", usersAuthenticate},
{"ChangeUsername", usersChangeUsername}, {"ChangeUsername", usersChangeUsername},
@ -134,7 +134,7 @@ func TestUsers(t *testing.T) {
} }
} }
func usersAuthenticate(t *testing.T, ctx context.Context, db *users) { func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
password := "pa$$word" password := "pa$$word"
alice, err := db.Create(ctx, "alice", "alice@example.com", alice, err := db.Create(ctx, "alice", "alice@example.com",
CreateUserOptions{ CreateUserOptions{
@ -229,7 +229,7 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *users) {
}) })
} }
func usersChangeUsername(t *testing.T, ctx context.Context, db *users) { func usersChangeUsername(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create( alice, err := db.Create(
ctx, ctx,
"alice", "alice",
@ -359,7 +359,7 @@ func usersChangeUsername(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, strings.ToUpper(newUsername), alice.Name) assert.Equal(t, strings.ToUpper(newUsername), alice.Name)
} }
func usersCount(t *testing.T, ctx context.Context, db *users) { func usersCount(t *testing.T, ctx context.Context, db *usersStore) {
// Has no user initially // Has no user initially
got := db.Count(ctx) got := db.Count(ctx)
assert.Equal(t, int64(0), got) assert.Equal(t, int64(0), got)
@ -382,7 +382,7 @@ func usersCount(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, int64(1), got) assert.Equal(t, int64(1), got)
} }
func usersCreate(t *testing.T, ctx context.Context, db *users) { func usersCreate(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create( alice, err := db.Create(
ctx, ctx,
"alice", "alice",
@ -430,7 +430,7 @@ func usersCreate(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339)) assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
} }
func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) { func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -464,7 +464,7 @@ func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) {
assert.False(t, alice.UseCustomAvatar) assert.False(t, alice.UseCustomAvatar)
} }
func usersDeleteByID(t *testing.T, ctx context.Context, db *users) { func usersDeleteByID(t *testing.T, ctx context.Context, db *usersStore) {
reposStore := NewReposStore(db.DB) reposStore := NewReposStore(db.DB)
t.Run("user still has repository ownership", func(t *testing.T) { t.Run("user still has repository ownership", func(t *testing.T) {
@ -674,7 +674,7 @@ func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) { func usersDeleteInactivated(t *testing.T, ctx context.Context, db *usersStore) {
// User with repository ownership should be skipped // User with repository ownership should be skipped
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -720,7 +720,7 @@ func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) {
require.Len(t, users, 3) require.Len(t, users, 3)
} }
func usersGetByEmail(t *testing.T, ctx context.Context, db *users) { func usersGetByEmail(t *testing.T, ctx context.Context, db *usersStore) {
t.Run("empty email", func(t *testing.T) { t.Run("empty email", func(t *testing.T) {
_, err := db.GetByEmail(ctx, "") _, err := db.GetByEmail(ctx, "")
wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}} wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}}
@ -781,7 +781,7 @@ func usersGetByEmail(t *testing.T, ctx context.Context, db *users) {
}) })
} }
func usersGetByID(t *testing.T, ctx context.Context, db *users) { func usersGetByID(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -794,7 +794,7 @@ func usersGetByID(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersGetByUsername(t *testing.T, ctx context.Context, db *users) { func usersGetByUsername(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -807,7 +807,7 @@ func usersGetByUsername(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) { func usersGetByKeyID(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -832,7 +832,7 @@ func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *users) { func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true}) bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true})
@ -846,7 +846,7 @@ func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *us
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) { func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -896,7 +896,7 @@ func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) {
} }
} }
func usersList(t *testing.T, ctx context.Context, db *users) { func usersList(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{})
@ -929,7 +929,7 @@ func usersList(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, bob.ID, got[1].ID) assert.Equal(t, bob.ID, got[1].ID)
} }
func usersListFollowers(t *testing.T, ctx context.Context, db *users) { func usersListFollowers(t *testing.T, ctx context.Context, db *usersStore) {
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -960,7 +960,7 @@ func usersListFollowers(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, alice.ID, got[0].ID) assert.Equal(t, alice.ID, got[0].ID)
} }
func usersListFollowings(t *testing.T, ctx context.Context, db *users) { func usersListFollowings(t *testing.T, ctx context.Context, db *usersStore) {
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -991,7 +991,7 @@ func usersListFollowings(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, alice.ID, got[0].ID) assert.Equal(t, alice.ID, got[0].ID)
} }
func usersSearchByName(t *testing.T, ctx context.Context, db *users) { func usersSearchByName(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"})
require.NoError(t, err) require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"}) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"})
@ -1029,7 +1029,7 @@ func usersSearchByName(t *testing.T, ctx context.Context, db *users) {
}) })
} }
func usersUpdate(t *testing.T, ctx context.Context, db *users) { func usersUpdate(t *testing.T, ctx context.Context, db *usersStore) {
const oldPassword = "Password" const oldPassword = "Password"
alice, err := db.Create( alice, err := db.Create(
ctx, ctx,
@ -1142,7 +1142,7 @@ func usersUpdate(t *testing.T, ctx context.Context, db *users) {
assertValues() assertValues()
} }
func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *users) { func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1180,7 +1180,7 @@ func TestIsUsernameAllowed(t *testing.T) {
} }
} }
func usersAddEmail(t *testing.T, ctx context.Context, db *users) { func usersAddEmail(t *testing.T, ctx context.Context, db *usersStore) {
t.Run("multiple users can add the same unverified email", func(t *testing.T) { t.Run("multiple users can add the same unverified email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1197,7 +1197,7 @@ func usersAddEmail(t *testing.T, ctx context.Context, db *users) {
}) })
} }
func usersGetEmail(t *testing.T, ctx context.Context, db *users) { func usersGetEmail(t *testing.T, ctx context.Context, db *usersStore) {
const testUserID = 1 const testUserID = 1
const testEmail = "alice@example.com" const testEmail = "alice@example.com"
_, err := db.GetEmail(ctx, testUserID, testEmail, false) _, err := db.GetEmail(ctx, testUserID, testEmail, false)
@ -1229,7 +1229,7 @@ func usersGetEmail(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, testEmail, got.Email) assert.Equal(t, testEmail, got.Email)
} }
func usersListEmails(t *testing.T, ctx context.Context, db *users) { func usersListEmails(t *testing.T, ctx context.Context, db *usersStore) {
t.Run("list emails with primary email", func(t *testing.T) { t.Run("list emails with primary email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1265,7 +1265,7 @@ func usersListEmails(t *testing.T, ctx context.Context, db *users) {
}) })
} }
func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) { func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1283,7 +1283,7 @@ func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) {
assert.NotEqual(t, alice.Rands, gotAlice.Rands) assert.NotEqual(t, alice.Rands, gotAlice.Rands)
} }
func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) { func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false) err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false)
@ -1309,7 +1309,7 @@ func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) {
assert.False(t, gotEmail.IsActivated) assert.False(t, gotEmail.IsActivated)
} }
func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) { func usersDeleteEmail(t *testing.T, ctx context.Context, db *usersStore) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1325,7 +1325,7 @@ func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) {
require.Equal(t, want, got) require.Equal(t, want, got)
} }
func usersFollow(t *testing.T, ctx context.Context, db *users) { func usersFollow(t *testing.T, ctx context.Context, db *usersStore) {
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1348,7 +1348,7 @@ func usersFollow(t *testing.T, ctx context.Context, db *users) {
assert.Equal(t, 1, bob.NumFollowers) assert.Equal(t, 1, bob.NumFollowers)
} }
func usersIsFollowing(t *testing.T, ctx context.Context, db *users) { func usersIsFollowing(t *testing.T, ctx context.Context, db *usersStore) {
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1369,7 +1369,7 @@ func usersIsFollowing(t *testing.T, ctx context.Context, db *users) {
assert.False(t, got) assert.False(t, got)
} }
func usersUnfollow(t *testing.T, ctx context.Context, db *users) { func usersUnfollow(t *testing.T, ctx context.Context, db *usersStore) {
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)