mirror of
https://github.com/gogs/gogs.git
synced 2025-05-23 16:00:56 +00:00
internal/database: consistently use Store
and s
as receiver (#7669)
This commit is contained in:
parent
dfe27ad556
commit
917c14f2ce
@ -74,9 +74,9 @@ func (t *AccessToken) AfterFind(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ AccessTokensStore = (*accessTokens)(nil)
|
||||
var _ AccessTokensStore = (*accessTokensStore)(nil)
|
||||
|
||||
type accessTokens struct {
|
||||
type accessTokensStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
@ -93,8 +93,8 @@ func (err ErrAccessTokenAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("access token already exists: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) {
|
||||
err := db.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
|
||||
func (s *accessTokensStore) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) {
|
||||
err := s.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
|
||||
if err == nil {
|
||||
return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}}
|
||||
} else if err != gorm.ErrRecordNotFound {
|
||||
@ -110,7 +110,7 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (
|
||||
Sha1: sha256[:40], // To pass the column unique constraint, keep the length of SHA1.
|
||||
SHA256: sha256,
|
||||
}
|
||||
if err = db.WithContext(ctx).Create(accessToken).Error; err != nil {
|
||||
if err = s.WithContext(ctx).Create(accessToken).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -119,8 +119,8 @@ func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (db *accessTokens) DeleteByID(ctx context.Context, userID, id int64) error {
|
||||
return db.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error
|
||||
func (s *accessTokensStore) DeleteByID(ctx context.Context, userID, id int64) error {
|
||||
return s.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error
|
||||
}
|
||||
|
||||
var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil)
|
||||
@ -144,7 +144,7 @@ func (ErrAccessTokenNotExist) NotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) {
|
||||
func (s *accessTokensStore) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) {
|
||||
// No need to waste a query for an empty SHA1.
|
||||
if sha1 == "" {
|
||||
return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}}
|
||||
@ -152,7 +152,7 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke
|
||||
|
||||
sha256 := cryptoutil.SHA256(sha1)
|
||||
token := new(AccessToken)
|
||||
err := db.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error
|
||||
err := s.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}}
|
||||
@ -162,15 +162,15 @@ func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToke
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (db *accessTokens) List(ctx context.Context, userID int64) ([]*AccessToken, error) {
|
||||
func (s *accessTokensStore) List(ctx context.Context, userID int64) ([]*AccessToken, error) {
|
||||
var tokens []*AccessToken
|
||||
return tokens, db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error
|
||||
return tokens, s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error
|
||||
}
|
||||
|
||||
func (db *accessTokens) Touch(ctx context.Context, id int64) error {
|
||||
return db.WithContext(ctx).
|
||||
func (s *accessTokensStore) Touch(ctx context.Context, id int64) error {
|
||||
return s.WithContext(ctx).
|
||||
Model(new(AccessToken)).
|
||||
Where("id = ?", id).
|
||||
UpdateColumn("updated_unix", db.NowFunc().Unix()).
|
||||
UpdateColumn("updated_unix", s.NowFunc().Unix()).
|
||||
Error
|
||||
}
|
||||
|
@ -98,13 +98,13 @@ func TestAccessTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &accessTokens{
|
||||
DB: newTestDB(t, "accessTokens"),
|
||||
db := &accessTokensStore{
|
||||
DB: newTestDB(t, "accessTokensStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *accessTokens)
|
||||
test func(t *testing.T, ctx context.Context, db *accessTokensStore)
|
||||
}{
|
||||
{"Create", accessTokensCreate},
|
||||
{"DeleteByID", accessTokensDeleteByID},
|
||||
@ -125,7 +125,7 @@ func TestAccessTokens(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokensStore) {
|
||||
// Create first access token with name "Test"
|
||||
token, err := db.Create(ctx, 1, "Test")
|
||||
require.NoError(t, err)
|
||||
@ -150,7 +150,7 @@ func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokensStore) {
|
||||
// Create an access token with name "Test"
|
||||
token, err := db.Create(ctx, 1, "Test")
|
||||
require.NoError(t, err)
|
||||
@ -177,7 +177,7 @@ func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens)
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokensStore) {
|
||||
// Create an access token with name "Test"
|
||||
token, err := db.Create(ctx, 1, "Test")
|
||||
require.NoError(t, err)
|
||||
@ -196,7 +196,7 @@ func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
func accessTokensList(t *testing.T, ctx context.Context, db *accessTokensStore) {
|
||||
// Create two access tokens for user 1
|
||||
_, err := db.Create(ctx, 1, "user1_1")
|
||||
require.NoError(t, err)
|
||||
@ -219,7 +219,7 @@ func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
assert.Equal(t, "user1_2", tokens[1].Name)
|
||||
}
|
||||
|
||||
func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokens) {
|
||||
func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokensStore) {
|
||||
// Create an access token with name "Test"
|
||||
token, err := db.Create(ctx, 1, "Test")
|
||||
require.NoError(t, err)
|
||||
|
@ -70,19 +70,19 @@ type ActionsStore interface {
|
||||
|
||||
var Actions ActionsStore
|
||||
|
||||
var _ ActionsStore = (*actions)(nil)
|
||||
var _ ActionsStore = (*actionsStore)(nil)
|
||||
|
||||
type actions struct {
|
||||
type actionsStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
// NewActionsStore returns a persistent interface for actions with given
|
||||
// database connection.
|
||||
func NewActionsStore(db *gorm.DB) ActionsStore {
|
||||
return &actions{DB: db}
|
||||
return &actionsStore{DB: db}
|
||||
}
|
||||
|
||||
func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB {
|
||||
func (s *actionsStore) listByOrganization(ctx context.Context, orgID, actorID, afterID int64) *gorm.DB {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -102,18 +102,18 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after
|
||||
ORDER BY id DESC
|
||||
LIMIT @limit
|
||||
*/
|
||||
return db.WithContext(ctx).
|
||||
return s.WithContext(ctx).
|
||||
Where("user_id = ?", orgID).
|
||||
Where(db.
|
||||
Where(s.
|
||||
// Not apply when afterID is not given
|
||||
Where("?", afterID <= 0).
|
||||
Or("id < ?", afterID),
|
||||
).
|
||||
Where("repo_id IN (?)", db.
|
||||
Where("repo_id IN (?)", s.
|
||||
Select("repository.id").
|
||||
Table("repository").
|
||||
Joins("JOIN team_repo ON repository.id = team_repo.repo_id").
|
||||
Where("team_repo.team_id IN (?)", db.
|
||||
Where("team_repo.team_id IN (?)", s.
|
||||
Select("team_id").
|
||||
Table("team_user").
|
||||
Where("team_user.org_id = ? AND uid = ?", orgID, actorID),
|
||||
@ -124,12 +124,12 @@ func (db *actions) listByOrganization(ctx context.Context, orgID, actorID, after
|
||||
Order("id DESC")
|
||||
}
|
||||
|
||||
func (db *actions) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) {
|
||||
func (s *actionsStore) ListByOrganization(ctx context.Context, orgID, actorID, afterID int64) ([]*Action, error) {
|
||||
actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum)
|
||||
return actions, db.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error
|
||||
return actions, s.listByOrganization(ctx, orgID, actorID, afterID).Find(&actions).Error
|
||||
}
|
||||
|
||||
func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB {
|
||||
func (s *actionsStore) listByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) *gorm.DB {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -141,14 +141,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6
|
||||
ORDER BY id DESC
|
||||
LIMIT @limit
|
||||
*/
|
||||
return db.WithContext(ctx).
|
||||
return s.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Where(db.
|
||||
Where(s.
|
||||
// Not apply when afterID is not given
|
||||
Where("?", afterID <= 0).
|
||||
Or("id < ?", afterID),
|
||||
).
|
||||
Where(db.
|
||||
Where(s.
|
||||
// Not apply when in not profile page or the user is viewing own profile
|
||||
Where("?", !isProfile || actorID == userID).
|
||||
Or("is_private = ? AND act_user_id = ?", false, userID),
|
||||
@ -157,14 +157,14 @@ func (db *actions) listByUser(ctx context.Context, userID, actorID, afterID int6
|
||||
Order("id DESC")
|
||||
}
|
||||
|
||||
func (db *actions) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) {
|
||||
func (s *actionsStore) ListByUser(ctx context.Context, userID, actorID, afterID int64, isProfile bool) ([]*Action, error) {
|
||||
actions := make([]*Action, 0, conf.UI.User.NewsFeedPagingNum)
|
||||
return actions, db.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error
|
||||
return actions, s.listByUser(ctx, userID, actorID, afterID, isProfile).Find(&actions).Error
|
||||
}
|
||||
|
||||
// notifyWatchers creates rows in action table for watchers who are able to see the action.
|
||||
func (db *actions) notifyWatchers(ctx context.Context, act *Action) error {
|
||||
watches, err := NewReposStore(db.DB).ListWatches(ctx, act.RepoID)
|
||||
func (s *actionsStore) notifyWatchers(ctx context.Context, act *Action) error {
|
||||
watches, err := NewReposStore(s.DB).ListWatches(ctx, act.RepoID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "list watches")
|
||||
}
|
||||
@ -187,16 +187,16 @@ func (db *actions) notifyWatchers(ctx context.Context, act *Action) error {
|
||||
actions = append(actions, clone(watch.UserID))
|
||||
}
|
||||
|
||||
return db.Create(actions).Error
|
||||
return s.Create(actions).Error
|
||||
}
|
||||
|
||||
func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error {
|
||||
func (s *actionsStore) NewRepo(ctx context.Context, doer, owner *User, repo *Repository) error {
|
||||
opType := ActionCreateRepo
|
||||
if repo.IsFork {
|
||||
opType = ActionForkRepo
|
||||
}
|
||||
|
||||
return db.notifyWatchers(ctx,
|
||||
return s.notifyWatchers(ctx,
|
||||
&Action{
|
||||
ActUserID: doer.ID,
|
||||
ActUserName: doer.Name,
|
||||
@ -209,8 +209,8 @@ func (db *actions) NewRepo(ctx context.Context, doer, owner *User, repo *Reposit
|
||||
)
|
||||
}
|
||||
|
||||
func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error {
|
||||
return db.notifyWatchers(ctx,
|
||||
func (s *actionsStore) RenameRepo(ctx context.Context, doer, owner *User, oldRepoName string, repo *Repository) error {
|
||||
return s.notifyWatchers(ctx,
|
||||
&Action{
|
||||
ActUserID: doer.ID,
|
||||
ActUserName: doer.Name,
|
||||
@ -224,8 +224,8 @@ func (db *actions) RenameRepo(ctx context.Context, doer, owner *User, oldRepoNam
|
||||
)
|
||||
}
|
||||
|
||||
func (db *actions) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error {
|
||||
return db.notifyWatchers(ctx,
|
||||
func (s *actionsStore) mirrorSyncAction(ctx context.Context, opType ActionType, owner *User, repo *Repository, refName string, content []byte) error {
|
||||
return s.notifyWatchers(ctx,
|
||||
&Action{
|
||||
ActUserID: owner.ID,
|
||||
ActUserName: owner.Name,
|
||||
@ -249,13 +249,13 @@ type MirrorSyncPushOptions struct {
|
||||
Commits *PushCommits
|
||||
}
|
||||
|
||||
func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error {
|
||||
func (s *actionsStore) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOptions) error {
|
||||
if conf.UI.FeedMaxCommitNum > 0 && len(opts.Commits.Commits) > conf.UI.FeedMaxCommitNum {
|
||||
opts.Commits.Commits = opts.Commits.Commits[:conf.UI.FeedMaxCommitNum]
|
||||
}
|
||||
|
||||
apiCommits, err := opts.Commits.APIFormat(ctx,
|
||||
NewUsersStore(db.DB),
|
||||
NewUsersStore(s.DB),
|
||||
repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name),
|
||||
repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name),
|
||||
)
|
||||
@ -288,19 +288,19 @@ func (db *actions) MirrorSyncPush(ctx context.Context, opts MirrorSyncPushOption
|
||||
return errors.Wrap(err, "marshal JSON")
|
||||
}
|
||||
|
||||
return db.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data)
|
||||
return s.mirrorSyncAction(ctx, ActionMirrorSyncPush, opts.Owner, opts.Repo, opts.RefName, data)
|
||||
}
|
||||
|
||||
func (db *actions) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error {
|
||||
return db.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil)
|
||||
func (s *actionsStore) MirrorSyncCreate(ctx context.Context, owner *User, repo *Repository, refName string) error {
|
||||
return s.mirrorSyncAction(ctx, ActionMirrorSyncCreate, owner, repo, refName, nil)
|
||||
}
|
||||
|
||||
func (db *actions) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error {
|
||||
return db.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil)
|
||||
func (s *actionsStore) MirrorSyncDelete(ctx context.Context, owner *User, repo *Repository, refName string) error {
|
||||
return s.mirrorSyncAction(ctx, ActionMirrorSyncDelete, owner, repo, refName, nil)
|
||||
}
|
||||
|
||||
func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error {
|
||||
return db.notifyWatchers(ctx,
|
||||
func (s *actionsStore) MergePullRequest(ctx context.Context, doer, owner *User, repo *Repository, pull *Issue) error {
|
||||
return s.notifyWatchers(ctx,
|
||||
&Action{
|
||||
ActUserID: doer.ID,
|
||||
ActUserName: doer.Name,
|
||||
@ -314,8 +314,8 @@ func (db *actions) MergePullRequest(ctx context.Context, doer, owner *User, repo
|
||||
)
|
||||
}
|
||||
|
||||
func (db *actions) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error {
|
||||
return db.notifyWatchers(ctx,
|
||||
func (s *actionsStore) TransferRepo(ctx context.Context, doer, oldOwner, newOwner *User, repo *Repository) error {
|
||||
return s.notifyWatchers(ctx,
|
||||
&Action{
|
||||
ActUserID: doer.ID,
|
||||
ActUserName: doer.Name,
|
||||
@ -487,13 +487,13 @@ type CommitRepoOptions struct {
|
||||
Commits *PushCommits
|
||||
}
|
||||
|
||||
func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error {
|
||||
err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID)
|
||||
func (s *actionsStore) CommitRepo(ctx context.Context, opts CommitRepoOptions) error {
|
||||
err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "touch repository")
|
||||
}
|
||||
|
||||
pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName)
|
||||
pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName)
|
||||
}
|
||||
@ -536,7 +536,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
|
||||
}
|
||||
|
||||
action.OpType = ActionDeleteBranch
|
||||
err = db.notifyWatchers(ctx, action)
|
||||
err = s.notifyWatchers(ctx, action)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "notify watchers")
|
||||
}
|
||||
@ -580,7 +580,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
|
||||
}
|
||||
|
||||
action.OpType = ActionCreateBranch
|
||||
err = db.notifyWatchers(ctx, action)
|
||||
err = s.notifyWatchers(ctx, action)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "notify watchers")
|
||||
}
|
||||
@ -589,7 +589,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
|
||||
}
|
||||
|
||||
commits, err := opts.Commits.APIFormat(ctx,
|
||||
NewUsersStore(db.DB),
|
||||
NewUsersStore(s.DB),
|
||||
repoutil.RepositoryPath(opts.Owner.Name, opts.Repo.Name),
|
||||
repoutil.HTMLURL(opts.Owner.Name, opts.Repo.Name),
|
||||
)
|
||||
@ -616,7 +616,7 @@ func (db *actions) CommitRepo(ctx context.Context, opts CommitRepoOptions) error
|
||||
}
|
||||
|
||||
action.OpType = ActionCommitRepo
|
||||
err = db.notifyWatchers(ctx, action)
|
||||
err = s.notifyWatchers(ctx, action)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "notify watchers")
|
||||
}
|
||||
@ -631,13 +631,13 @@ type PushTagOptions struct {
|
||||
NewCommitID string
|
||||
}
|
||||
|
||||
func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
|
||||
err := NewReposStore(db.DB).Touch(ctx, opts.Repo.ID)
|
||||
func (s *actionsStore) PushTag(ctx context.Context, opts PushTagOptions) error {
|
||||
err := NewReposStore(s.DB).Touch(ctx, opts.Repo.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "touch repository")
|
||||
}
|
||||
|
||||
pusher, err := NewUsersStore(db.DB).GetByUsername(ctx, opts.PusherName)
|
||||
pusher, err := NewUsersStore(s.DB).GetByUsername(ctx, opts.PusherName)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "get pusher [name: %s]", opts.PusherName)
|
||||
}
|
||||
@ -672,7 +672,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
|
||||
}
|
||||
|
||||
action.OpType = ActionDeleteTag
|
||||
err = db.notifyWatchers(ctx, action)
|
||||
err = s.notifyWatchers(ctx, action)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "notify watchers")
|
||||
}
|
||||
@ -696,7 +696,7 @@ func (db *actions) PushTag(ctx context.Context, opts PushTagOptions) error {
|
||||
}
|
||||
|
||||
action.OpType = ActionPushTag
|
||||
err = db.notifyWatchers(ctx, action)
|
||||
err = s.notifyWatchers(ctx, action)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "notify watchers")
|
||||
}
|
||||
|
@ -99,13 +99,13 @@ func TestActions(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
t.Parallel()
|
||||
db := &actions{
|
||||
DB: newTestDB(t, "actions"),
|
||||
db := &actionsStore{
|
||||
DB: newTestDB(t, "actionsStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *actions)
|
||||
test func(t *testing.T, ctx context.Context, db *actionsStore)
|
||||
}{
|
||||
{"CommitRepo", actionsCommitRepo},
|
||||
{"ListByOrganization", actionsListByOrganization},
|
||||
@ -132,7 +132,7 @@ func TestActions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsCommitRepo(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
repo, err := NewReposStore(db.DB).Create(ctx,
|
||||
@ -324,7 +324,7 @@ func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) {
|
||||
})
|
||||
}
|
||||
|
||||
func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsListByOrganization(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
|
||||
t.Skip("Skipping testing with not using PostgreSQL")
|
||||
return
|
||||
@ -363,14 +363,14 @@ func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return NewActionsStore(tx).(*actions).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action))
|
||||
return NewActionsStore(tx).(*actionsStore).listByOrganization(ctx, test.orgID, test.actorID, test.afterID).Find(new(Action))
|
||||
})
|
||||
assert.Equal(t, test.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func actionsListByUser(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsListByUser(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
|
||||
t.Skip("Skipping testing with not using PostgreSQL")
|
||||
return
|
||||
@ -428,14 +428,14 @@ func actionsListByUser(t *testing.T, ctx context.Context, db *actions) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := db.DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return NewActionsStore(tx).(*actions).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action))
|
||||
return NewActionsStore(tx).(*actionsStore).listByUser(ctx, test.userID, test.actorID, test.afterID, test.isProfile).Find(new(Action))
|
||||
})
|
||||
assert.Equal(t, test.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
repo, err := NewReposStore(db.DB).Create(ctx,
|
||||
@ -480,7 +480,7 @@ func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) {
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
repo, err := NewReposStore(db.DB).Create(ctx,
|
||||
@ -521,7 +521,7 @@ func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) {
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
repo, err := NewReposStore(db.DB).Create(ctx,
|
||||
@ -562,7 +562,7 @@ func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) {
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
repo, err := NewReposStore(db.DB).Create(ctx,
|
||||
@ -627,7 +627,7 @@ func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) {
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsNewRepo(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
repo, err := NewReposStore(db.DB).Create(ctx,
|
||||
@ -702,7 +702,7 @@ func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) {
|
||||
})
|
||||
}
|
||||
|
||||
func actionsPushTag(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsPushTag(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
// NOTE: We set a noop mock here to avoid data race with other tests that writes
|
||||
// to the mock server because this function holds a lock.
|
||||
conf.SetMockServer(t, conf.ServerOpts{})
|
||||
@ -798,7 +798,7 @@ func actionsPushTag(t *testing.T, ctx context.Context, db *actions) {
|
||||
})
|
||||
}
|
||||
|
||||
func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsRenameRepo(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
repo, err := NewReposStore(db.DB).Create(ctx,
|
||||
@ -835,7 +835,7 @@ func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) {
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func actionsTransferRepo(t *testing.T, ctx context.Context, db *actions) {
|
||||
func actionsTransferRepo(t *testing.T, ctx context.Context, db *actionsStore) {
|
||||
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{})
|
||||
|
@ -123,15 +123,15 @@ func Init(w logger.Writer) (*gorm.DB, error) {
|
||||
}
|
||||
|
||||
// Initialize stores, sorted in alphabetical order.
|
||||
AccessTokens = &accessTokens{DB: db}
|
||||
AccessTokens = &accessTokensStore{DB: db}
|
||||
Actions = NewActionsStore(db)
|
||||
LoginSources = &loginSources{DB: db, files: sourceFiles}
|
||||
LFS = &lfs{DB: db}
|
||||
LoginSources = &loginSourcesStore{DB: db, files: sourceFiles}
|
||||
LFS = &lfsStore{DB: db}
|
||||
Notices = NewNoticesStore(db)
|
||||
Orgs = NewOrgsStore(db)
|
||||
Perms = NewPermsStore(db)
|
||||
Repos = NewReposStore(db)
|
||||
TwoFactors = &twoFactors{DB: db}
|
||||
TwoFactors = &twoFactorsStore{DB: db}
|
||||
Users = NewUsersStore(db)
|
||||
|
||||
return db, nil
|
||||
|
@ -38,20 +38,20 @@ type LFSObject struct {
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
var _ LFSStore = (*lfs)(nil)
|
||||
var _ LFSStore = (*lfsStore)(nil)
|
||||
|
||||
type lfs struct {
|
||||
type lfsStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
func (db *lfs) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
|
||||
func (s *lfsStore) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
|
||||
object := &LFSObject{
|
||||
RepoID: repoID,
|
||||
OID: oid,
|
||||
Size: size,
|
||||
Storage: storage,
|
||||
}
|
||||
return db.WithContext(ctx).Create(object).Error
|
||||
return s.WithContext(ctx).Create(object).Error
|
||||
}
|
||||
|
||||
type ErrLFSObjectNotExist struct {
|
||||
@ -71,9 +71,9 @@ func (ErrLFSObjectNotExist) NotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) {
|
||||
func (s *lfsStore) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) {
|
||||
object := new(LFSObject)
|
||||
err := db.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
|
||||
err := s.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}}
|
||||
@ -83,13 +83,13 @@ func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID
|
||||
return object, err
|
||||
}
|
||||
|
||||
func (db *lfs) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
|
||||
func (s *lfsStore) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
|
||||
if len(oids) == 0 {
|
||||
return []*LFSObject{}, nil
|
||||
}
|
||||
|
||||
objects := make([]*LFSObject, 0, len(oids))
|
||||
err := db.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error
|
||||
err := s.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -23,13 +23,13 @@ func TestLFS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &lfs{
|
||||
DB: newTestDB(t, "lfs"),
|
||||
db := &lfsStore{
|
||||
DB: newTestDB(t, "lfsStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *lfs)
|
||||
test func(t *testing.T, ctx context.Context, db *lfsStore)
|
||||
}{
|
||||
{"CreateObject", lfsCreateObject},
|
||||
{"GetObjectByOID", lfsGetObjectByOID},
|
||||
@ -48,7 +48,7 @@ func TestLFS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) {
|
||||
func lfsCreateObject(t *testing.T, ctx context.Context, db *lfsStore) {
|
||||
// Create first LFS object
|
||||
repoID := int64(1)
|
||||
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
|
||||
@ -65,7 +65,7 @@ func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) {
|
||||
func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfsStore) {
|
||||
// Create a LFS object
|
||||
repoID := int64(1)
|
||||
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
|
||||
@ -82,7 +82,7 @@ func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) {
|
||||
assert.Equal(t, expErr, err)
|
||||
}
|
||||
|
||||
func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfs) {
|
||||
func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfsStore) {
|
||||
// Create two LFS objects
|
||||
repoID := int64(1)
|
||||
oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
|
||||
|
@ -180,9 +180,9 @@ func (s *LoginSource) GitHub() *github.Config {
|
||||
return s.Provider.Config().(*github.Config)
|
||||
}
|
||||
|
||||
var _ LoginSourcesStore = (*loginSources)(nil)
|
||||
var _ LoginSourcesStore = (*loginSourcesStore)(nil)
|
||||
|
||||
type loginSources struct {
|
||||
type loginSourcesStore struct {
|
||||
*gorm.DB
|
||||
files loginSourceFilesStore
|
||||
}
|
||||
@ -208,8 +208,8 @@ func (err ErrLoginSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
|
||||
err := db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
|
||||
func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
|
||||
err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
|
||||
if err == nil {
|
||||
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
|
||||
} else if err != gorm.ErrRecordNotFound {
|
||||
@ -226,13 +226,13 @@ func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOption
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return source, db.WithContext(ctx).Create(source).Error
|
||||
return source, s.WithContext(ctx).Create(source).Error
|
||||
}
|
||||
|
||||
func (db *loginSources) Count(ctx context.Context) int64 {
|
||||
func (s *loginSourcesStore) Count(ctx context.Context) int64 {
|
||||
var count int64
|
||||
db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
|
||||
return count + int64(db.files.Len())
|
||||
s.WithContext(ctx).Model(new(LoginSource)).Count(&count)
|
||||
return count + int64(s.files.Len())
|
||||
}
|
||||
|
||||
type ErrLoginSourceInUse struct {
|
||||
@ -248,24 +248,24 @@ func (err ErrLoginSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *loginSources) DeleteByID(ctx context.Context, id int64) error {
|
||||
func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
|
||||
var count int64
|
||||
err := db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
|
||||
err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count > 0 {
|
||||
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
|
||||
return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
|
||||
}
|
||||
|
||||
func (db *loginSources) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
|
||||
func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
|
||||
source := new(LoginSource)
|
||||
err := db.WithContext(ctx).Where("id = ?", id).First(source).Error
|
||||
err := s.WithContext(ctx).Where("id = ?", id).First(source).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return db.files.GetByID(id)
|
||||
return s.files.GetByID(id)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@ -277,9 +277,9 @@ type ListLoginSourceOptions struct {
|
||||
OnlyActivated bool
|
||||
}
|
||||
|
||||
func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
|
||||
func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
|
||||
var sources []*LoginSource
|
||||
query := db.WithContext(ctx).Order("id ASC")
|
||||
query := s.WithContext(ctx).Order("id ASC")
|
||||
if opts.OnlyActivated {
|
||||
query = query.Where("is_actived = ?", true)
|
||||
}
|
||||
@ -288,11 +288,11 @@ func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOptions) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(sources, db.files.List(opts)...), nil
|
||||
return append(sources, s.files.List(opts)...), nil
|
||||
}
|
||||
|
||||
func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
|
||||
err := db.WithContext(ctx).
|
||||
func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
|
||||
err := s.WithContext(ctx).
|
||||
Model(new(LoginSource)).
|
||||
Where("id != ?", dflt.ID).
|
||||
Updates(map[string]any{"is_default": false}).
|
||||
@ -301,7 +301,7 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, source := range db.files.List(ListLoginSourceOptions{}) {
|
||||
for _, source := range s.files.List(ListLoginSourceOptions{}) {
|
||||
if source.File != nil && source.ID != dflt.ID {
|
||||
source.File.SetGeneral("is_default", "false")
|
||||
if err = source.File.Save(); err != nil {
|
||||
@ -310,13 +310,13 @@ func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource)
|
||||
}
|
||||
}
|
||||
|
||||
db.files.Update(dflt)
|
||||
s.files.Update(dflt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *loginSources) Save(ctx context.Context, source *LoginSource) error {
|
||||
func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
|
||||
if source.File == nil {
|
||||
return db.WithContext(ctx).Save(source).Error
|
||||
return s.WithContext(ctx).Save(source).Error
|
||||
}
|
||||
|
||||
source.File.SetGeneral("name", source.Name)
|
||||
|
@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &loginSources{
|
||||
DB: newTestDB(t, "loginSources"),
|
||||
db := &loginSourcesStore{
|
||||
DB: newTestDB(t, "loginSourcesStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *loginSources)
|
||||
test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
||||
}{
|
||||
{"Create", loginSourcesCreate},
|
||||
{"Count", loginSourcesCount},
|
||||
@ -192,7 +192,7 @@ func TestLoginSources(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
// Create first login source with name "GitHub"
|
||||
source, err := db.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
@ -219,7 +219,7 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
// Create two login sources, one in database and one as source file.
|
||||
_, err := db.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
@ -241,7 +241,7 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
assert.Equal(t, int64(3), db.Count(ctx))
|
||||
}
|
||||
|
||||
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
t.Run("delete but in used", func(t *testing.T) {
|
||||
source, err := db.Create(ctx,
|
||||
CreateLoginSourceOptions{
|
||||
@ -257,7 +257,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a user that uses this login source
|
||||
_, err = (&users{DB: db.DB}).Create(ctx, "alice", "",
|
||||
_, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
|
||||
CreateUserOptions{
|
||||
LoginSource: source.ID,
|
||||
},
|
||||
@ -308,7 +308,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources)
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
||||
if id != 101 {
|
||||
@ -344,7 +344,7 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
||||
if opts.OnlyActivated {
|
||||
@ -393,7 +393,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
assert.Equal(t, 2, len(sources), "number of sources")
|
||||
}
|
||||
|
||||
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
||||
mockFile := NewMockLoginSourceFileStore()
|
||||
@ -448,7 +448,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
|
||||
assert.False(t, source2.IsDefault)
|
||||
}
|
||||
|
||||
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) {
|
||||
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
||||
t.Run("save to database", func(t *testing.T) {
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(ctx,
|
||||
|
@ -32,7 +32,7 @@ func setMockLoginSourcesStore(t *testing.T, mock LoginSourcesStore) {
|
||||
})
|
||||
}
|
||||
|
||||
func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) {
|
||||
func setMockLoginSourceFilesStore(t *testing.T, db *loginSourcesStore, mock loginSourceFilesStore) {
|
||||
before := db.files
|
||||
db.files = mock
|
||||
t.Cleanup(func() {
|
||||
|
@ -32,20 +32,20 @@ type NoticesStore interface {
|
||||
|
||||
var Notices NoticesStore
|
||||
|
||||
var _ NoticesStore = (*notices)(nil)
|
||||
var _ NoticesStore = (*noticesStore)(nil)
|
||||
|
||||
type notices struct {
|
||||
type noticesStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
// NewNoticesStore returns a persistent interface for system notices with given
|
||||
// database connection.
|
||||
func NewNoticesStore(db *gorm.DB) NoticesStore {
|
||||
return ¬ices{DB: db}
|
||||
return ¬icesStore{DB: db}
|
||||
}
|
||||
|
||||
func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) error {
|
||||
return db.WithContext(ctx).Create(
|
||||
func (s *noticesStore) Create(ctx context.Context, typ NoticeType, desc string) error {
|
||||
return s.WithContext(ctx).Create(
|
||||
&Notice{
|
||||
Type: typ,
|
||||
Description: desc,
|
||||
@ -53,26 +53,26 @@ func (db *notices) Create(ctx context.Context, typ NoticeType, desc string) erro
|
||||
).Error
|
||||
}
|
||||
|
||||
func (db *notices) DeleteByIDs(ctx context.Context, ids ...int64) error {
|
||||
return db.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error
|
||||
func (s *noticesStore) DeleteByIDs(ctx context.Context, ids ...int64) error {
|
||||
return s.WithContext(ctx).Where("id IN (?)", ids).Delete(&Notice{}).Error
|
||||
}
|
||||
|
||||
func (db *notices) DeleteAll(ctx context.Context) error {
|
||||
return db.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error
|
||||
func (s *noticesStore) DeleteAll(ctx context.Context) error {
|
||||
return s.WithContext(ctx).Where("TRUE").Delete(&Notice{}).Error
|
||||
}
|
||||
|
||||
func (db *notices) List(ctx context.Context, page, pageSize int) ([]*Notice, error) {
|
||||
func (s *noticesStore) List(ctx context.Context, page, pageSize int) ([]*Notice, error) {
|
||||
notices := make([]*Notice, 0, pageSize)
|
||||
return notices, db.WithContext(ctx).
|
||||
return notices, s.WithContext(ctx).
|
||||
Limit(pageSize).Offset((page - 1) * pageSize).
|
||||
Order("id DESC").
|
||||
Find(¬ices).
|
||||
Error
|
||||
}
|
||||
|
||||
func (db *notices) Count(ctx context.Context) int64 {
|
||||
func (s *noticesStore) Count(ctx context.Context) int64 {
|
||||
var count int64
|
||||
db.WithContext(ctx).Model(&Notice{}).Count(&count)
|
||||
s.WithContext(ctx).Model(&Notice{}).Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
|
@ -65,13 +65,13 @@ func TestNotices(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := ¬ices{
|
||||
DB: newTestDB(t, "notices"),
|
||||
db := ¬icesStore{
|
||||
DB: newTestDB(t, "noticesStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *notices)
|
||||
test func(t *testing.T, ctx context.Context, db *noticesStore)
|
||||
}{
|
||||
{"Create", noticesCreate},
|
||||
{"DeleteByIDs", noticesDeleteByIDs},
|
||||
@ -92,7 +92,7 @@ func TestNotices(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
|
||||
func noticesCreate(t *testing.T, ctx context.Context, db *noticesStore) {
|
||||
err := db.Create(ctx, NoticeTypeRepository, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -100,7 +100,7 @@ func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
|
||||
assert.Equal(t, int64(1), count)
|
||||
}
|
||||
|
||||
func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
|
||||
func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *noticesStore) {
|
||||
err := db.Create(ctx, NoticeTypeRepository, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -120,7 +120,7 @@ func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
|
||||
func noticesDeleteAll(t *testing.T, ctx context.Context, db *noticesStore) {
|
||||
err := db.Create(ctx, NoticeTypeRepository, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -131,7 +131,7 @@ func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func noticesList(t *testing.T, ctx context.Context, db *notices) {
|
||||
func noticesList(t *testing.T, ctx context.Context, db *noticesStore) {
|
||||
err := db.Create(ctx, NoticeTypeRepository, "test 1")
|
||||
require.NoError(t, err)
|
||||
err = db.Create(ctx, NoticeTypeRepository, "test 2")
|
||||
@ -151,7 +151,7 @@ func noticesList(t *testing.T, ctx context.Context, db *notices) {
|
||||
require.Len(t, got, 2)
|
||||
}
|
||||
|
||||
func noticesCount(t *testing.T, ctx context.Context, db *notices) {
|
||||
func noticesCount(t *testing.T, ctx context.Context, db *noticesStore) {
|
||||
count := db.Count(ctx)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
|
@ -30,16 +30,16 @@ type OrgsStore interface {
|
||||
|
||||
var Orgs OrgsStore
|
||||
|
||||
var _ OrgsStore = (*orgs)(nil)
|
||||
var _ OrgsStore = (*orgsStore)(nil)
|
||||
|
||||
type orgs struct {
|
||||
type orgsStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
// NewOrgsStore returns a persistent interface for orgs with given database
|
||||
// connection.
|
||||
func NewOrgsStore(db *gorm.DB) OrgsStore {
|
||||
return &orgs{DB: db}
|
||||
return &orgsStore{DB: db}
|
||||
}
|
||||
|
||||
type ListOrgsOptions struct {
|
||||
@ -49,7 +49,7 @@ type ListOrgsOptions struct {
|
||||
IncludePrivateMembers bool
|
||||
}
|
||||
|
||||
func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) {
|
||||
func (s *orgsStore) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) {
|
||||
if opts.MemberID <= 0 {
|
||||
return nil, errors.New("MemberID must be greater than 0")
|
||||
}
|
||||
@ -64,7 +64,7 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization
|
||||
[AND org_user.is_public = @includePrivateMembers]
|
||||
ORDER BY org.id ASC
|
||||
*/
|
||||
tx := db.WithContext(ctx).
|
||||
tx := s.WithContext(ctx).
|
||||
Joins(dbutil.Quote("JOIN org_user ON org_user.org_id = %s.id", "user")).
|
||||
Where("org_user.uid = ?", opts.MemberID).
|
||||
Order(dbutil.Quote("%s.id ASC", "user"))
|
||||
@ -76,13 +76,13 @@ func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization
|
||||
return orgs, tx.Find(&orgs).Error
|
||||
}
|
||||
|
||||
func (db *orgs) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) {
|
||||
return searchUserByName(ctx, db.DB, UserTypeOrganization, keyword, page, pageSize, orderBy)
|
||||
func (s *orgsStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) {
|
||||
return searchUserByName(ctx, s.DB, UserTypeOrganization, keyword, page, pageSize, orderBy)
|
||||
}
|
||||
|
||||
func (db *orgs) CountByUser(ctx context.Context, userID int64) (int64, error) {
|
||||
func (s *orgsStore) CountByUser(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
return count, db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
|
||||
return count, s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
|
||||
}
|
||||
|
||||
type Organization = User
|
||||
|
@ -21,13 +21,13 @@ func TestOrgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &orgs{
|
||||
DB: newTestDB(t, "orgs"),
|
||||
db := &orgsStore{
|
||||
DB: newTestDB(t, "orgsStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *orgs)
|
||||
test func(t *testing.T, ctx context.Context, db *orgsStore)
|
||||
}{
|
||||
{"List", orgsList},
|
||||
{"SearchByName", orgsSearchByName},
|
||||
@ -46,7 +46,7 @@ func TestOrgs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func orgsList(t *testing.T, ctx context.Context, db *orgs) {
|
||||
func orgsList(t *testing.T, ctx context.Context, db *orgsStore) {
|
||||
usersStore := NewUsersStore(db.DB)
|
||||
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
@ -116,7 +116,7 @@ func orgsList(t *testing.T, ctx context.Context, db *orgs) {
|
||||
}
|
||||
}
|
||||
|
||||
func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) {
|
||||
func orgsSearchByName(t *testing.T, ctx context.Context, db *orgsStore) {
|
||||
// TODO: Use Orgs.Create to replace SQL hack when the method is available.
|
||||
usersStore := NewUsersStore(db.DB)
|
||||
org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"})
|
||||
@ -161,7 +161,7 @@ func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) {
|
||||
})
|
||||
}
|
||||
|
||||
func orgsCountByUser(t *testing.T, ctx context.Context, db *orgs) {
|
||||
func orgsCountByUser(t *testing.T, ctx context.Context, db *orgsStore) {
|
||||
// TODO: Use Orgs.Join to replace SQL hack when the method is available.
|
||||
err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error
|
||||
require.NoError(t, err)
|
||||
|
@ -74,16 +74,16 @@ func ParseAccessMode(permission string) AccessMode {
|
||||
}
|
||||
}
|
||||
|
||||
var _ PermsStore = (*perms)(nil)
|
||||
var _ PermsStore = (*permsStore)(nil)
|
||||
|
||||
type perms struct {
|
||||
type permsStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
// NewPermsStore returns a persistent interface for permissions with given
|
||||
// database connection.
|
||||
func NewPermsStore(db *gorm.DB) PermsStore {
|
||||
return &perms{DB: db}
|
||||
return &permsStore{DB: db}
|
||||
}
|
||||
|
||||
type AccessModeOptions struct {
|
||||
@ -91,7 +91,7 @@ type AccessModeOptions struct {
|
||||
Private bool // Whether the repository is private.
|
||||
}
|
||||
|
||||
func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) {
|
||||
func (s *permsStore) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) {
|
||||
if repoID <= 0 {
|
||||
return AccessModeNone
|
||||
}
|
||||
@ -111,7 +111,7 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce
|
||||
}
|
||||
|
||||
access := new(Access)
|
||||
err := db.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error
|
||||
err := s.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error
|
||||
if err != nil {
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repoID, err)
|
||||
@ -121,11 +121,11 @@ func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts Acce
|
||||
return access.Mode
|
||||
}
|
||||
|
||||
func (db *perms) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
|
||||
return desired <= db.AccessMode(ctx, userID, repoID, opts)
|
||||
func (s *permsStore) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
|
||||
return desired <= s.AccessMode(ctx, userID, repoID, opts)
|
||||
}
|
||||
|
||||
func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error {
|
||||
func (s *permsStore) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error {
|
||||
records := make([]*Access, 0, len(accessMap))
|
||||
for userID, mode := range accessMap {
|
||||
records = append(records, &Access{
|
||||
@ -135,7 +135,7 @@ func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[i
|
||||
})
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Where("repo_id = ?", repoID).Delete(new(Access)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -19,13 +19,13 @@ func TestPerms(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &perms{
|
||||
DB: newTestDB(t, "perms"),
|
||||
db := &permsStore{
|
||||
DB: newTestDB(t, "permsStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *perms)
|
||||
test func(t *testing.T, ctx context.Context, db *permsStore)
|
||||
}{
|
||||
{"AccessMode", permsAccessMode},
|
||||
{"Authorize", permsAuthorize},
|
||||
@ -44,7 +44,7 @@ func TestPerms(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func permsAccessMode(t *testing.T, ctx context.Context, db *perms) {
|
||||
func permsAccessMode(t *testing.T, ctx context.Context, db *permsStore) {
|
||||
// Set up permissions
|
||||
err := db.SetRepoPerms(ctx, 1,
|
||||
map[int64]AccessMode{
|
||||
@ -155,7 +155,7 @@ func permsAccessMode(t *testing.T, ctx context.Context, db *perms) {
|
||||
}
|
||||
}
|
||||
|
||||
func permsAuthorize(t *testing.T, ctx context.Context, db *perms) {
|
||||
func permsAuthorize(t *testing.T, ctx context.Context, db *permsStore) {
|
||||
// Set up permissions
|
||||
err := db.SetRepoPerms(ctx, 1,
|
||||
map[int64]AccessMode{
|
||||
@ -241,7 +241,7 @@ func permsAuthorize(t *testing.T, ctx context.Context, db *perms) {
|
||||
}
|
||||
}
|
||||
|
||||
func permsSetRepoPerms(t *testing.T, ctx context.Context, db *perms) {
|
||||
func permsSetRepoPerms(t *testing.T, ctx context.Context, db *permsStore) {
|
||||
for _, update := range []struct {
|
||||
repoID int64
|
||||
accessMap map[int64]AccessMode
|
||||
|
@ -24,23 +24,23 @@ type PublicKeysStore interface {
|
||||
|
||||
var PublicKeys PublicKeysStore
|
||||
|
||||
var _ PublicKeysStore = (*publicKeys)(nil)
|
||||
var _ PublicKeysStore = (*publicKeysStore)(nil)
|
||||
|
||||
type publicKeys struct {
|
||||
type publicKeysStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
// NewPublicKeysStore returns a persistent interface for public keys with given
|
||||
// database connection.
|
||||
func NewPublicKeysStore(db *gorm.DB) PublicKeysStore {
|
||||
return &publicKeys{DB: db}
|
||||
return &publicKeysStore{DB: db}
|
||||
}
|
||||
|
||||
func authorizedKeysPath() string {
|
||||
return filepath.Join(conf.SSH.RootPath, "authorized_keys")
|
||||
}
|
||||
|
||||
func (db *publicKeys) RewriteAuthorizedKeys() error {
|
||||
func (s *publicKeysStore) RewriteAuthorizedKeys() error {
|
||||
sshOpLocker.Lock()
|
||||
defer sshOpLocker.Unlock()
|
||||
|
||||
@ -61,7 +61,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error {
|
||||
|
||||
// NOTE: More recently updated keys are more likely to be used more frequently,
|
||||
// putting them in the earlier lines could speed up the key lookup by SSHD.
|
||||
rows, err := db.Model(&PublicKey{}).Order("updated_unix DESC").Rows()
|
||||
rows, err := s.Model(&PublicKey{}).Order("updated_unix DESC").Rows()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "iterate public keys")
|
||||
}
|
||||
@ -69,7 +69,7 @@ func (db *publicKeys) RewriteAuthorizedKeys() error {
|
||||
|
||||
for rows.Next() {
|
||||
var key PublicKey
|
||||
err = db.ScanRows(rows, &key)
|
||||
err = s.ScanRows(rows, &key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "scan rows")
|
||||
}
|
||||
|
@ -24,13 +24,13 @@ func TestPublicKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &publicKeys{
|
||||
DB: newTestDB(t, "publicKeys"),
|
||||
db := &publicKeysStore{
|
||||
DB: newTestDB(t, "publicKeysStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *publicKeys)
|
||||
test func(t *testing.T, ctx context.Context, db *publicKeysStore)
|
||||
}{
|
||||
{"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys},
|
||||
} {
|
||||
@ -47,7 +47,7 @@ func TestPublicKeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeys) {
|
||||
func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeysStore) {
|
||||
// TODO: Use PublicKeys.Add to replace SQL hack when the method is available.
|
||||
publicKey := &PublicKey{
|
||||
OwnerID: 1,
|
||||
|
@ -119,16 +119,16 @@ func (r *Repository) APIFormat(owner *User, opts ...RepositoryAPIFormatOptions)
|
||||
}
|
||||
}
|
||||
|
||||
var _ ReposStore = (*repos)(nil)
|
||||
var _ ReposStore = (*reposStore)(nil)
|
||||
|
||||
type repos struct {
|
||||
type reposStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
// NewReposStore returns a persistent interface for repositories with given
|
||||
// database connection.
|
||||
func NewReposStore(db *gorm.DB) ReposStore {
|
||||
return &repos{DB: db}
|
||||
return &reposStore{DB: db}
|
||||
}
|
||||
|
||||
type ErrRepoAlreadyExist struct {
|
||||
@ -157,13 +157,13 @@ type CreateRepoOptions struct {
|
||||
ForkID int64
|
||||
}
|
||||
|
||||
func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) {
|
||||
func (s *reposStore) Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) {
|
||||
err := isRepoNameAllowed(opts.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.GetByName(ctx, ownerID, opts.Name)
|
||||
_, err = s.GetByName(ctx, ownerID, opts.Name)
|
||||
if err == nil {
|
||||
return nil, ErrRepoAlreadyExist{
|
||||
args: errutil.Args{
|
||||
@ -189,7 +189,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio
|
||||
IsFork: opts.Fork,
|
||||
ForkID: opts.ForkID,
|
||||
}
|
||||
return repo, db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return repo, s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err = tx.Create(repo).Error
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create")
|
||||
@ -203,7 +203,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio
|
||||
})
|
||||
}
|
||||
|
||||
func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) {
|
||||
func (s *reposStore) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -214,7 +214,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64,
|
||||
LIMIT @limit
|
||||
*/
|
||||
var repos []*Repository
|
||||
return repos, db.WithContext(ctx).
|
||||
return repos, s.WithContext(ctx).
|
||||
Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID).
|
||||
Where("access.mode >= ?", AccessModeRead).
|
||||
Order(orderBy).
|
||||
@ -223,7 +223,7 @@ func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64,
|
||||
Error
|
||||
}
|
||||
|
||||
func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) {
|
||||
func (s *reposStore) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -238,7 +238,7 @@ func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collabor
|
||||
*Repository
|
||||
Mode AccessMode
|
||||
}
|
||||
err := db.WithContext(ctx).
|
||||
err := s.WithContext(ctx).
|
||||
Select("repository.*", "access.mode").
|
||||
Table("repository").
|
||||
Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID).
|
||||
@ -275,9 +275,9 @@ func (ErrRepoNotExist) NotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) {
|
||||
func (s *reposStore) GetByID(ctx context.Context, id int64) (*Repository, error) {
|
||||
repo := new(Repository)
|
||||
err := db.WithContext(ctx).Where("id = ?", id).First(repo).Error
|
||||
err := s.WithContext(ctx).Where("id = ?", id).First(repo).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrRepoNotExist{errutil.Args{"repoID": id}}
|
||||
@ -287,9 +287,9 @@ func (db *repos) GetByID(ctx context.Context, id int64) (*Repository, error) {
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) {
|
||||
func (s *reposStore) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) {
|
||||
repo := new(Repository)
|
||||
err := db.WithContext(ctx).
|
||||
err := s.WithContext(ctx).
|
||||
Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).
|
||||
First(repo).
|
||||
Error
|
||||
@ -307,7 +307,7 @@ func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Re
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error {
|
||||
func (s *reposStore) recountStars(tx *gorm.DB, userID, repoID int64) error {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -350,40 +350,40 @@ func (db *repos) recountStars(tx *gorm.DB, userID, repoID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *repos) Star(ctx context.Context, userID, repoID int64) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
s := &Star{
|
||||
func (s *reposStore) Star(ctx context.Context, userID, repoID int64) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
star := &Star{
|
||||
UserID: userID,
|
||||
RepoID: repoID,
|
||||
}
|
||||
result := tx.FirstOrCreate(s, s)
|
||||
result := tx.FirstOrCreate(star, star)
|
||||
if result.Error != nil {
|
||||
return errors.Wrap(result.Error, "upsert")
|
||||
} else if result.RowsAffected <= 0 {
|
||||
return nil // Relation already exists
|
||||
}
|
||||
|
||||
return db.recountStars(tx, userID, repoID)
|
||||
return s.recountStars(tx, userID, repoID)
|
||||
})
|
||||
}
|
||||
|
||||
func (db *repos) Touch(ctx context.Context, id int64) error {
|
||||
return db.WithContext(ctx).
|
||||
func (s *reposStore) Touch(ctx context.Context, id int64) error {
|
||||
return s.WithContext(ctx).
|
||||
Model(new(Repository)).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"is_bare": false,
|
||||
"updated_unix": db.NowFunc().Unix(),
|
||||
"updated_unix": s.NowFunc().Unix(),
|
||||
}).
|
||||
Error
|
||||
}
|
||||
|
||||
func (db *repos) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) {
|
||||
func (s *reposStore) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) {
|
||||
var watches []*Watch
|
||||
return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error
|
||||
return watches, s.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error
|
||||
}
|
||||
|
||||
func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error {
|
||||
func (s *reposStore) recountWatches(tx *gorm.DB, repoID int64) error {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -402,8 +402,8 @@ func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error {
|
||||
Error
|
||||
}
|
||||
|
||||
func (db *repos) Watch(ctx context.Context, userID, repoID int64) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
func (s *reposStore) Watch(ctx context.Context, userID, repoID int64) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
w := &Watch{
|
||||
UserID: userID,
|
||||
RepoID: repoID,
|
||||
@ -415,12 +415,12 @@ func (db *repos) Watch(ctx context.Context, userID, repoID int64) error {
|
||||
return nil // Relation already exists
|
||||
}
|
||||
|
||||
return db.recountWatches(tx, repoID)
|
||||
return s.recountWatches(tx, repoID)
|
||||
})
|
||||
}
|
||||
|
||||
func (db *repos) HasForkedBy(ctx context.Context, repoID, userID int64) bool {
|
||||
func (s *reposStore) HasForkedBy(ctx context.Context, repoID, userID int64) bool {
|
||||
var count int64
|
||||
db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count)
|
||||
s.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
@ -85,13 +85,13 @@ func TestRepos(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &repos{
|
||||
db := &reposStore{
|
||||
DB: newTestDB(t, "repos"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *repos)
|
||||
test func(t *testing.T, ctx context.Context, db *reposStore)
|
||||
}{
|
||||
{"Create", reposCreate},
|
||||
{"GetByCollaboratorID", reposGetByCollaboratorID},
|
||||
@ -117,7 +117,7 @@ func TestRepos(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func reposCreate(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposCreate(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
t.Run("name not allowed", func(t *testing.T) {
|
||||
_, err := db.Create(ctx,
|
||||
1,
|
||||
@ -159,7 +159,7 @@ func reposCreate(t *testing.T, ctx context.Context, db *repos) {
|
||||
assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default.
|
||||
}
|
||||
|
||||
func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
|
||||
require.NoError(t, err)
|
||||
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
|
||||
@ -185,7 +185,7 @@ func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) {
|
||||
})
|
||||
}
|
||||
|
||||
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
|
||||
require.NoError(t, err)
|
||||
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
|
||||
@ -213,7 +213,7 @@ func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, d
|
||||
assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID])
|
||||
}
|
||||
|
||||
func reposGetByID(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposGetByID(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -226,7 +226,7 @@ func reposGetByID(t *testing.T, ctx context.Context, db *repos) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func reposGetByName(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposGetByName(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
repo, err := db.Create(ctx, 1,
|
||||
CreateRepoOptions{
|
||||
Name: "repo1",
|
||||
@ -242,7 +242,7 @@ func reposGetByName(t *testing.T, ctx context.Context, db *repos) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func reposStar(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposStar(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
|
||||
require.NoError(t, err)
|
||||
usersStore := NewUsersStore(db.DB)
|
||||
@ -261,7 +261,7 @@ func reposStar(t *testing.T, ctx context.Context, db *repos) {
|
||||
assert.Equal(t, 1, alice.NumStars)
|
||||
}
|
||||
|
||||
func reposTouch(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposTouch(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
repo, err := db.Create(ctx, 1,
|
||||
CreateRepoOptions{
|
||||
Name: "repo1",
|
||||
@ -287,7 +287,7 @@ func reposTouch(t *testing.T, ctx context.Context, db *repos) {
|
||||
assert.False(t, got.IsBare)
|
||||
}
|
||||
|
||||
func reposListWatches(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposListWatches(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
err := db.Watch(ctx, 1, 1)
|
||||
require.NoError(t, err)
|
||||
err = db.Watch(ctx, 2, 1)
|
||||
@ -308,7 +308,7 @@ func reposListWatches(t *testing.T, ctx context.Context, db *repos) {
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func reposWatch(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposWatch(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
reposStore := NewReposStore(db.DB)
|
||||
repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
|
||||
require.NoError(t, err)
|
||||
@ -325,7 +325,7 @@ func reposWatch(t *testing.T, ctx context.Context, db *repos) {
|
||||
assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default.
|
||||
}
|
||||
|
||||
func reposHasForkedBy(t *testing.T, ctx context.Context, db *repos) {
|
||||
func reposHasForkedBy(t *testing.T, ctx context.Context, db *reposStore) {
|
||||
has := db.HasForkedBy(ctx, 1, 2)
|
||||
assert.False(t, has)
|
||||
|
||||
|
@ -50,13 +50,13 @@ func (t *TwoFactor) AfterFind(_ *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ TwoFactorsStore = (*twoFactors)(nil)
|
||||
var _ TwoFactorsStore = (*twoFactorsStore)(nil)
|
||||
|
||||
type twoFactors struct {
|
||||
type twoFactorsStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret string) error {
|
||||
func (s *twoFactorsStore) Create(ctx context.Context, userID int64, key, secret string) error {
|
||||
encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "encrypt secret")
|
||||
@ -71,7 +71,7 @@ func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret stri
|
||||
return errors.Wrap(err, "generate recovery codes")
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Create(tf).Error
|
||||
if err != nil {
|
||||
return err
|
||||
@ -100,9 +100,9 @@ func (ErrTwoFactorNotFound) NotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) {
|
||||
func (s *twoFactorsStore) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) {
|
||||
tf := new(TwoFactor)
|
||||
err := db.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error
|
||||
err := s.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}}
|
||||
@ -112,9 +112,9 @@ func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor
|
||||
return tf, nil
|
||||
}
|
||||
|
||||
func (db *twoFactors) IsEnabled(ctx context.Context, userID int64) bool {
|
||||
func (s *twoFactorsStore) IsEnabled(ctx context.Context, userID int64) bool {
|
||||
var count int64
|
||||
err := db.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
|
||||
err := s.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
log.Error("Failed to count two factors [user_id: %d]: %v", userID, err)
|
||||
}
|
||||
|
@ -67,13 +67,13 @@ func TestTwoFactors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &twoFactors{
|
||||
DB: newTestDB(t, "twoFactors"),
|
||||
db := &twoFactorsStore{
|
||||
DB: newTestDB(t, "twoFactorsStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *twoFactors)
|
||||
test func(t *testing.T, ctx context.Context, db *twoFactorsStore)
|
||||
}{
|
||||
{"Create", twoFactorsCreate},
|
||||
{"GetByUserID", twoFactorsGetByUserID},
|
||||
@ -92,7 +92,7 @@ func TestTwoFactors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) {
|
||||
func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactorsStore) {
|
||||
// Create a 2FA token
|
||||
err := db.Create(ctx, 1, "secure-key", "secure-secret")
|
||||
require.NoError(t, err)
|
||||
@ -109,7 +109,7 @@ func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) {
|
||||
assert.Equal(t, int64(10), count)
|
||||
}
|
||||
|
||||
func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) {
|
||||
func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactorsStore) {
|
||||
// Create a 2FA token for user 1
|
||||
err := db.Create(ctx, 1, "secure-key", "secure-secret")
|
||||
require.NoError(t, err)
|
||||
@ -124,7 +124,7 @@ func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactors) {
|
||||
func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactorsStore) {
|
||||
// Create a 2FA token for user 1
|
||||
err := db.Create(ctx, 1, "secure-key", "secure-secret")
|
||||
require.NoError(t, err)
|
||||
|
@ -146,16 +146,16 @@ type UsersStore interface {
|
||||
|
||||
var Users UsersStore
|
||||
|
||||
var _ UsersStore = (*users)(nil)
|
||||
var _ UsersStore = (*usersStore)(nil)
|
||||
|
||||
type users struct {
|
||||
type usersStore struct {
|
||||
*gorm.DB
|
||||
}
|
||||
|
||||
// NewUsersStore returns a persistent interface for users with given database
|
||||
// connection.
|
||||
func NewUsersStore(db *gorm.DB) UsersStore {
|
||||
return &users{DB: db}
|
||||
return &usersStore{DB: db}
|
||||
}
|
||||
|
||||
type ErrLoginSourceMismatch struct {
|
||||
@ -173,10 +173,10 @@ func (err ErrLoginSourceMismatch) Error() string {
|
||||
return fmt.Sprintf("login source mismatch: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *users) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) {
|
||||
func (s *usersStore) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) {
|
||||
login = strings.ToLower(login)
|
||||
|
||||
query := db.WithContext(ctx)
|
||||
query := s.WithContext(ctx)
|
||||
if strings.Contains(login, "@") {
|
||||
query = query.Where("email = ?", login)
|
||||
} else {
|
||||
@ -244,7 +244,7 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login
|
||||
return nil, fmt.Errorf("invalid pattern for attribute 'username' [%s]: must be valid alpha or numeric or dash(-_) or dot characters", extAccount.Name)
|
||||
}
|
||||
|
||||
return db.Create(ctx, extAccount.Name, extAccount.Email,
|
||||
return s.Create(ctx, extAccount.Name, extAccount.Email,
|
||||
CreateUserOptions{
|
||||
FullName: extAccount.FullName,
|
||||
LoginSource: authSourceID,
|
||||
@ -257,13 +257,13 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login
|
||||
)
|
||||
}
|
||||
|
||||
func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername string) error {
|
||||
func (s *usersStore) ChangeUsername(ctx context.Context, userID int64, newUsername string) error {
|
||||
err := isUsernameAllowed(newUsername)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if db.IsUsernameUsed(ctx, newUsername, userID) {
|
||||
if s.IsUsernameUsed(ctx, newUsername, userID) {
|
||||
return ErrUserAlreadyExist{
|
||||
args: errutil.Args{
|
||||
"name": newUsername,
|
||||
@ -271,12 +271,12 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
|
||||
}
|
||||
}
|
||||
|
||||
user, err := db.GetByID(ctx, userID)
|
||||
user, err := s.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get user")
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Model(&User{}).
|
||||
Where("id = ?", user.ID).
|
||||
Updates(map[string]any{
|
||||
@ -338,9 +338,9 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
|
||||
})
|
||||
}
|
||||
|
||||
func (db *users) Count(ctx context.Context) int64 {
|
||||
func (s *usersStore) Count(ctx context.Context) int64 {
|
||||
var count int64
|
||||
db.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count)
|
||||
s.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
@ -393,13 +393,13 @@ func (err ErrEmailAlreadyUsed) Error() string {
|
||||
return fmt.Sprintf("email has been used: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *users) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) {
|
||||
func (s *usersStore) Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) {
|
||||
err := isUsernameAllowed(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if db.IsUsernameUsed(ctx, username, 0) {
|
||||
if s.IsUsernameUsed(ctx, username, 0) {
|
||||
return nil, ErrUserAlreadyExist{
|
||||
args: errutil.Args{
|
||||
"name": username,
|
||||
@ -408,7 +408,7 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
|
||||
}
|
||||
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
_, err = db.GetByEmail(ctx, email)
|
||||
_, err = s.GetByEmail(ctx, email)
|
||||
if err == nil {
|
||||
return nil, ErrEmailAlreadyUsed{
|
||||
args: errutil.Args{
|
||||
@ -446,17 +446,17 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
|
||||
}
|
||||
user.Password = userutil.EncodePassword(user.Password, user.Salt)
|
||||
|
||||
return user, db.WithContext(ctx).Create(user).Error
|
||||
return user, s.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
func (db *users) DeleteCustomAvatar(ctx context.Context, userID int64) error {
|
||||
func (s *usersStore) DeleteCustomAvatar(ctx context.Context, userID int64) error {
|
||||
_ = os.Remove(userutil.CustomAvatarPath(userID))
|
||||
return db.WithContext(ctx).
|
||||
return s.WithContext(ctx).
|
||||
Model(&User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]any{
|
||||
"use_custom_avatar": false,
|
||||
"updated_unix": db.NowFunc().Unix(),
|
||||
"updated_unix": s.NowFunc().Unix(),
|
||||
}).
|
||||
Error
|
||||
}
|
||||
@ -491,8 +491,8 @@ func (err ErrUserHasOrgs) Error() string {
|
||||
return fmt.Sprintf("user still has organization membership: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error {
|
||||
user, err := db.GetByID(ctx, userID)
|
||||
func (s *usersStore) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error {
|
||||
user, err := s.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if IsErrUserNotExist(err) {
|
||||
return nil
|
||||
@ -503,14 +503,14 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
|
||||
// Double check the user is not a direct owner of any repository and not a
|
||||
// member of any organization.
|
||||
var count int64
|
||||
err = db.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error
|
||||
err = s.WithContext(ctx).Model(&Repository{}).Where("owner_id = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "count repositories")
|
||||
} else if count > 0 {
|
||||
return ErrUserOwnRepos{args: errutil.Args{"userID": userID}}
|
||||
}
|
||||
|
||||
err = db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
|
||||
err = s.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "count organization membership")
|
||||
} else if count > 0 {
|
||||
@ -518,7 +518,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
|
||||
}
|
||||
|
||||
needsRewriteAuthorizedKeys := false
|
||||
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err = s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -645,7 +645,7 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
|
||||
_ = os.Remove(userutil.CustomAvatarPath(userID))
|
||||
|
||||
if needsRewriteAuthorizedKeys {
|
||||
err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys()
|
||||
err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, `rewrite "authorized_keys" file`)
|
||||
}
|
||||
@ -655,15 +655,15 @@ func (db *users) DeleteByID(ctx context.Context, userID int64, skipRewriteAuthor
|
||||
|
||||
// NOTE: We do not take context.Context here because this operation in practice
|
||||
// could much longer than the general request timeout (e.g. one minute).
|
||||
func (db *users) DeleteInactivated() error {
|
||||
func (s *usersStore) DeleteInactivated() error {
|
||||
var userIDs []int64
|
||||
err := db.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error
|
||||
err := s.Model(&User{}).Where("is_active = ?", false).Pluck("id", &userIDs).Error
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get inactivated user IDs")
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
err = db.DeleteByID(context.Background(), userID, true)
|
||||
err = s.DeleteByID(context.Background(), userID, true)
|
||||
if err != nil {
|
||||
// Skip users that may had set to inactivated by admins.
|
||||
if IsErrUserOwnRepos(err) || IsErrUserHasOrgs(err) {
|
||||
@ -672,14 +672,14 @@ func (db *users) DeleteInactivated() error {
|
||||
return errors.Wrapf(err, "delete user with ID %d", userID)
|
||||
}
|
||||
}
|
||||
err = NewPublicKeysStore(db.DB).RewriteAuthorizedKeys()
|
||||
err = NewPublicKeysStore(s.DB).RewriteAuthorizedKeys()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, `rewrite "authorized_keys" file`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error {
|
||||
func (*usersStore) recountFollows(tx *gorm.DB, userID, followID int64) error {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -722,12 +722,12 @@ func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *users) Follow(ctx context.Context, userID, followID int64) error {
|
||||
func (s *usersStore) Follow(ctx context.Context, userID, followID int64) error {
|
||||
if userID == followID {
|
||||
return nil
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
f := &Follow{
|
||||
UserID: userID,
|
||||
FollowID: followID,
|
||||
@ -739,26 +739,26 @@ func (db *users) Follow(ctx context.Context, userID, followID int64) error {
|
||||
return nil // Relation already exists
|
||||
}
|
||||
|
||||
return db.recountFollows(tx, userID, followID)
|
||||
return s.recountFollows(tx, userID, followID)
|
||||
})
|
||||
}
|
||||
|
||||
func (db *users) Unfollow(ctx context.Context, userID, followID int64) error {
|
||||
func (s *usersStore) Unfollow(ctx context.Context, userID, followID int64) error {
|
||||
if userID == followID {
|
||||
return nil
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "delete")
|
||||
}
|
||||
return db.recountFollows(tx, userID, followID)
|
||||
return s.recountFollows(tx, userID, followID)
|
||||
})
|
||||
}
|
||||
|
||||
func (db *users) IsFollowing(ctx context.Context, userID, followID int64) bool {
|
||||
return db.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil
|
||||
func (s *usersStore) IsFollowing(ctx context.Context, userID, followID int64) bool {
|
||||
return s.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil
|
||||
}
|
||||
|
||||
var _ errutil.NotFound = (*ErrUserNotExist)(nil)
|
||||
@ -782,7 +782,7 @@ func (ErrUserNotExist) NotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
func (s *usersStore) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
if email == "" {
|
||||
return nil, ErrUserNotExist{args: errutil.Args{"email": email}}
|
||||
}
|
||||
@ -801,10 +801,10 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
)
|
||||
*/
|
||||
user := new(User)
|
||||
err := db.WithContext(ctx).
|
||||
err := s.WithContext(ctx).
|
||||
Joins(dbutil.Quote("LEFT JOIN email_address ON email_address.uid = %s.id", "user"), true).
|
||||
Where(dbutil.Quote("%s.type = ?", "user"), UserTypeIndividual).
|
||||
Where(db.
|
||||
Where(s.
|
||||
Where(dbutil.Quote("%[1]s.email = ? AND %[1]s.is_active = ?", "user"), email, true).
|
||||
Or("email_address.email = ? AND email_address.is_activated = ?", email, true),
|
||||
).
|
||||
@ -819,9 +819,9 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *users) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
func (s *usersStore) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
user := new(User)
|
||||
err := db.WithContext(ctx).Where("id = ?", id).First(user).Error
|
||||
err := s.WithContext(ctx).Where("id = ?", id).First(user).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrUserNotExist{args: errutil.Args{"userID": id}}
|
||||
@ -831,9 +831,9 @@ func (db *users) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *users) GetByUsername(ctx context.Context, username string) (*User, error) {
|
||||
func (s *usersStore) GetByUsername(ctx context.Context, username string) (*User, error) {
|
||||
user := new(User)
|
||||
err := db.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error
|
||||
err := s.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrUserNotExist{args: errutil.Args{"name": username}}
|
||||
@ -843,9 +843,9 @@ func (db *users) GetByUsername(ctx context.Context, username string) (*User, err
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
|
||||
func (s *usersStore) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
|
||||
user := new(User)
|
||||
err := db.WithContext(ctx).
|
||||
err := s.WithContext(ctx).
|
||||
Joins(dbutil.Quote("JOIN public_key ON public_key.owner_id = %s.id", "user")).
|
||||
Where("public_key.id = ?", keyID).
|
||||
First(user).
|
||||
@ -859,29 +859,29 @@ func (db *users) GetByKeyID(ctx context.Context, keyID int64) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *users) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) {
|
||||
func (s *usersStore) GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) {
|
||||
emails := make([]string, 0, len(usernames))
|
||||
return emails, db.WithContext(ctx).
|
||||
return emails, s.WithContext(ctx).
|
||||
Model(&User{}).
|
||||
Select("email").
|
||||
Where("lower_name IN (?) AND is_active = ?", usernames, true).
|
||||
Find(&emails).Error
|
||||
}
|
||||
|
||||
func (db *users) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool {
|
||||
func (s *usersStore) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool {
|
||||
if username == "" {
|
||||
return false
|
||||
}
|
||||
return db.WithContext(ctx).
|
||||
return s.WithContext(ctx).
|
||||
Select("id").
|
||||
Where("lower_name = ? AND id != ?", strings.ToLower(username), excludeUserId).
|
||||
First(&User{}).
|
||||
Error != gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error) {
|
||||
func (s *usersStore) List(ctx context.Context, page, pageSize int) ([]*User, error) {
|
||||
users := make([]*User, 0, pageSize)
|
||||
return users, db.WithContext(ctx).
|
||||
return users, s.WithContext(ctx).
|
||||
Where("type = ?", UserTypeIndividual).
|
||||
Limit(pageSize).Offset((page - 1) * pageSize).
|
||||
Order("id ASC").
|
||||
@ -889,7 +889,7 @@ func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error)
|
||||
Error
|
||||
}
|
||||
|
||||
func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
|
||||
func (s *usersStore) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -900,7 +900,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize
|
||||
LIMIT @limit OFFSET @offset
|
||||
*/
|
||||
users := make([]*User, 0, pageSize)
|
||||
return users, db.WithContext(ctx).
|
||||
return users, s.WithContext(ctx).
|
||||
Joins(dbutil.Quote("LEFT JOIN follow ON follow.user_id = %s.id", "user")).
|
||||
Where("follow.follow_id = ?", userID).
|
||||
Limit(pageSize).Offset((page - 1) * pageSize).
|
||||
@ -909,7 +909,7 @@ func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize
|
||||
Error
|
||||
}
|
||||
|
||||
func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
|
||||
func (s *usersStore) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
|
||||
/*
|
||||
Equivalent SQL for PostgreSQL:
|
||||
|
||||
@ -920,7 +920,7 @@ func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSiz
|
||||
LIMIT @limit OFFSET @offset
|
||||
*/
|
||||
users := make([]*User, 0, pageSize)
|
||||
return users, db.WithContext(ctx).
|
||||
return users, s.WithContext(ctx).
|
||||
Joins(dbutil.Quote("LEFT JOIN follow ON follow.follow_id = %s.id", "user")).
|
||||
Where("follow.user_id = ?", userID).
|
||||
Limit(pageSize).Offset((page - 1) * pageSize).
|
||||
@ -948,8 +948,8 @@ func searchUserByName(ctx context.Context, db *gorm.DB, userType UserType, keywo
|
||||
return users, count, tx.Order(orderBy).Limit(pageSize).Offset((page - 1) * pageSize).Find(&users).Error
|
||||
}
|
||||
|
||||
func (db *users) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) {
|
||||
return searchUserByName(ctx, db.DB, UserTypeIndividual, keyword, page, pageSize, orderBy)
|
||||
func (s *usersStore) SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) {
|
||||
return searchUserByName(ctx, s.DB, UserTypeIndividual, keyword, page, pageSize, orderBy)
|
||||
}
|
||||
|
||||
type UpdateUserOptions struct {
|
||||
@ -979,9 +979,9 @@ type UpdateUserOptions struct {
|
||||
AvatarEmail *string
|
||||
}
|
||||
|
||||
func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error {
|
||||
func (s *usersStore) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error {
|
||||
updates := map[string]any{
|
||||
"updated_unix": db.NowFunc().Unix(),
|
||||
"updated_unix": s.NowFunc().Unix(),
|
||||
}
|
||||
|
||||
if opts.LoginSource != nil {
|
||||
@ -1012,7 +1012,7 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption
|
||||
updates["full_name"] = strutil.Truncate(*opts.FullName, 255)
|
||||
}
|
||||
if opts.Email != nil {
|
||||
_, err := db.GetByEmail(ctx, *opts.Email)
|
||||
_, err := s.GetByEmail(ctx, *opts.Email)
|
||||
if err == nil {
|
||||
return ErrEmailAlreadyUsed{args: errutil.Args{"email": *opts.Email}}
|
||||
} else if !IsErrUserNotExist(err) {
|
||||
@ -1063,28 +1063,28 @@ func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOption
|
||||
updates["avatar_email"] = strutil.Truncate(*opts.AvatarEmail, 255)
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error
|
||||
return s.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error
|
||||
}
|
||||
|
||||
func (db *users) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error {
|
||||
func (s *usersStore) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error {
|
||||
err := userutil.SaveAvatar(userID, avatar)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "save avatar")
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).
|
||||
return s.WithContext(ctx).
|
||||
Model(&User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]any{
|
||||
"use_custom_avatar": true,
|
||||
"updated_unix": db.NowFunc().Unix(),
|
||||
"updated_unix": s.NowFunc().Unix(),
|
||||
}).
|
||||
Error
|
||||
}
|
||||
|
||||
func (db *users) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error {
|
||||
func (s *usersStore) AddEmail(ctx context.Context, userID int64, email string, isActivated bool) error {
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
_, err := db.GetByEmail(ctx, email)
|
||||
_, err := s.GetByEmail(ctx, email)
|
||||
if err == nil {
|
||||
return ErrEmailAlreadyUsed{
|
||||
args: errutil.Args{
|
||||
@ -1095,7 +1095,7 @@ func (db *users) AddEmail(ctx context.Context, userID int64, email string, isAct
|
||||
return errors.Wrap(err, "check user by email")
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Create(
|
||||
return s.WithContext(ctx).Create(
|
||||
&EmailAddress{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
@ -1125,8 +1125,8 @@ func (ErrEmailNotExist) NotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *users) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) {
|
||||
tx := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email)
|
||||
func (s *usersStore) GetEmail(ctx context.Context, userID int64, email string, needsActivated bool) (*EmailAddress, error) {
|
||||
tx := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email)
|
||||
if needsActivated {
|
||||
tx = tx.Where("is_activated = ?", true)
|
||||
}
|
||||
@ -1146,14 +1146,14 @@ func (db *users) GetEmail(ctx context.Context, userID int64, email string, needs
|
||||
return emailAddress, nil
|
||||
}
|
||||
|
||||
func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) {
|
||||
user, err := db.GetByID(ctx, userID)
|
||||
func (s *usersStore) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress, error) {
|
||||
user, err := s.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get user")
|
||||
}
|
||||
|
||||
var emails []*EmailAddress
|
||||
err = db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error
|
||||
err = s.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&emails).Error
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "list emails")
|
||||
}
|
||||
@ -1179,9 +1179,9 @@ func (db *users) ListEmails(ctx context.Context, userID int64) ([]*EmailAddress,
|
||||
return emails, nil
|
||||
}
|
||||
|
||||
func (db *users) MarkEmailActivated(ctx context.Context, userID int64, email string) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := db.WithContext(ctx).
|
||||
func (s *usersStore) MarkEmailActivated(ctx context.Context, userID int64, email string) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := s.WithContext(ctx).
|
||||
Model(&EmailAddress{}).
|
||||
Where("uid = ? AND email = ?", userID, email).
|
||||
Update("is_activated", true).
|
||||
@ -1209,9 +1209,9 @@ func (err ErrEmailNotVerified) Error() string {
|
||||
return fmt.Sprintf("email has not been verified: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email string) error {
|
||||
func (s *usersStore) MarkEmailPrimary(ctx context.Context, userID int64, email string) error {
|
||||
var emailAddress EmailAddress
|
||||
err := db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error
|
||||
err := s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).First(&emailAddress).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return ErrEmailNotExist{args: errutil.Args{"email": email}}
|
||||
@ -1223,12 +1223,12 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin
|
||||
return ErrEmailNotVerified{args: errutil.Args{"email": email}}
|
||||
}
|
||||
|
||||
user, err := db.GetByID(ctx, userID)
|
||||
user, err := s.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get user")
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return s.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Make sure the former primary email doesn't disappear.
|
||||
err = tx.FirstOrCreate(
|
||||
&EmailAddress{
|
||||
@ -1255,8 +1255,8 @@ func (db *users) MarkEmailPrimary(ctx context.Context, userID int64, email strin
|
||||
})
|
||||
}
|
||||
|
||||
func (db *users) DeleteEmail(ctx context.Context, userID int64, email string) error {
|
||||
return db.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error
|
||||
func (s *usersStore) DeleteEmail(ctx context.Context, userID int64, email string) error {
|
||||
return s.WithContext(ctx).Where("uid = ? AND email = ?", userID, email).Delete(&EmailAddress{}).Error
|
||||
}
|
||||
|
||||
// UserType indicates the type of the user account.
|
||||
|
@ -84,13 +84,13 @@ func TestUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := &users{
|
||||
DB: newTestDB(t, "users"),
|
||||
db := &usersStore{
|
||||
DB: newTestDB(t, "usersStore"),
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test func(t *testing.T, ctx context.Context, db *users)
|
||||
test func(t *testing.T, ctx context.Context, db *usersStore)
|
||||
}{
|
||||
{"Authenticate", usersAuthenticate},
|
||||
{"ChangeUsername", usersChangeUsername},
|
||||
@ -134,7 +134,7 @@ func TestUsers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func usersAuthenticate(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
password := "pa$$word"
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com",
|
||||
CreateUserOptions{
|
||||
@ -229,7 +229,7 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *users) {
|
||||
})
|
||||
}
|
||||
|
||||
func usersChangeUsername(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersChangeUsername(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(
|
||||
ctx,
|
||||
"alice",
|
||||
@ -359,7 +359,7 @@ func usersChangeUsername(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, strings.ToUpper(newUsername), alice.Name)
|
||||
}
|
||||
|
||||
func usersCount(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersCount(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
// Has no user initially
|
||||
got := db.Count(ctx)
|
||||
assert.Equal(t, int64(0), got)
|
||||
@ -382,7 +382,7 @@ func usersCount(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, int64(1), got)
|
||||
}
|
||||
|
||||
func usersCreate(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersCreate(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(
|
||||
ctx,
|
||||
"alice",
|
||||
@ -430,7 +430,7 @@ func usersCreate(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -464,7 +464,7 @@ func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.False(t, alice.UseCustomAvatar)
|
||||
}
|
||||
|
||||
func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersDeleteByID(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
reposStore := NewReposStore(db.DB)
|
||||
|
||||
t.Run("user still has repository ownership", func(t *testing.T) {
|
||||
@ -674,7 +674,7 @@ func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersDeleteInactivated(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
// User with repository ownership should be skipped
|
||||
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
@ -720,7 +720,7 @@ func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) {
|
||||
require.Len(t, users, 3)
|
||||
}
|
||||
|
||||
func usersGetByEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersGetByEmail(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
t.Run("empty email", func(t *testing.T) {
|
||||
_, err := db.GetByEmail(ctx, "")
|
||||
wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}}
|
||||
@ -781,7 +781,7 @@ func usersGetByEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
})
|
||||
}
|
||||
|
||||
func usersGetByID(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersGetByID(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -794,7 +794,7 @@ func usersGetByID(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func usersGetByUsername(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersGetByUsername(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -807,7 +807,7 @@ func usersGetByUsername(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersGetByKeyID(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -832,7 +832,7 @@ func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, wantErr, err)
|
||||
}
|
||||
|
||||
func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true})
|
||||
@ -846,7 +846,7 @@ func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *us
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -896,7 +896,7 @@ func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) {
|
||||
}
|
||||
}
|
||||
|
||||
func usersList(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersList(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{})
|
||||
@ -929,7 +929,7 @@ func usersList(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, bob.ID, got[1].ID)
|
||||
}
|
||||
|
||||
func usersListFollowers(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersListFollowers(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -960,7 +960,7 @@ func usersListFollowers(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, alice.ID, got[0].ID)
|
||||
}
|
||||
|
||||
func usersListFollowings(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersListFollowings(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -991,7 +991,7 @@ func usersListFollowings(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, alice.ID, got[0].ID)
|
||||
}
|
||||
|
||||
func usersSearchByName(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersSearchByName(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"})
|
||||
require.NoError(t, err)
|
||||
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"})
|
||||
@ -1029,7 +1029,7 @@ func usersSearchByName(t *testing.T, ctx context.Context, db *users) {
|
||||
})
|
||||
}
|
||||
|
||||
func usersUpdate(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersUpdate(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
const oldPassword = "Password"
|
||||
alice, err := db.Create(
|
||||
ctx,
|
||||
@ -1142,7 +1142,7 @@ func usersUpdate(t *testing.T, ctx context.Context, db *users) {
|
||||
assertValues()
|
||||
}
|
||||
|
||||
func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -1180,7 +1180,7 @@ func TestIsUsernameAllowed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func usersAddEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersAddEmail(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
t.Run("multiple users can add the same unverified email", func(t *testing.T) {
|
||||
alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
@ -1197,7 +1197,7 @@ func usersAddEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
})
|
||||
}
|
||||
|
||||
func usersGetEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersGetEmail(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
const testUserID = 1
|
||||
const testEmail = "alice@example.com"
|
||||
_, err := db.GetEmail(ctx, testUserID, testEmail, false)
|
||||
@ -1229,7 +1229,7 @@ func usersGetEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, testEmail, got.Email)
|
||||
}
|
||||
|
||||
func usersListEmails(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersListEmails(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
t.Run("list emails with primary email", func(t *testing.T) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
@ -1265,7 +1265,7 @@ func usersListEmails(t *testing.T, ctx context.Context, db *users) {
|
||||
})
|
||||
}
|
||||
|
||||
func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -1283,7 +1283,7 @@ func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.NotEqual(t, alice.Rands, gotAlice.Rands)
|
||||
}
|
||||
|
||||
func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false)
|
||||
@ -1309,7 +1309,7 @@ func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.False(t, gotEmail.IsActivated)
|
||||
}
|
||||
|
||||
func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersDeleteEmail(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -1325,7 +1325,7 @@ func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) {
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func usersFollow(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersFollow(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
usersStore := NewUsersStore(db.DB)
|
||||
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
@ -1348,7 +1348,7 @@ func usersFollow(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.Equal(t, 1, bob.NumFollowers)
|
||||
}
|
||||
|
||||
func usersIsFollowing(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersIsFollowing(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
usersStore := NewUsersStore(db.DB)
|
||||
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
@ -1369,7 +1369,7 @@ func usersIsFollowing(t *testing.T, ctx context.Context, db *users) {
|
||||
assert.False(t, got)
|
||||
}
|
||||
|
||||
func usersUnfollow(t *testing.T, ctx context.Context, db *users) {
|
||||
func usersUnfollow(t *testing.T, ctx context.Context, db *usersStore) {
|
||||
usersStore := NewUsersStore(db.DB)
|
||||
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
|
||||
require.NoError(t, err)
|
||||
|
Loading…
x
Reference in New Issue
Block a user