db: pass context to tests by default (#7622)

[skip ci]
pull/7623/head
Joe Chen 2023-12-17 16:32:28 -05:00 committed by GitHub
parent 0c7b45ad1f
commit 25fdeaac49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 114 additions and 255 deletions

View File

@ -98,6 +98,7 @@ func TestAccessTokens(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(AccessToken)} tables := []any{new(AccessToken)}
db := &accessTokens{ db := &accessTokens{
DB: dbtest.NewDB(t, "accessTokens", tables...), DB: dbtest.NewDB(t, "accessTokens", tables...),
@ -105,7 +106,7 @@ func TestAccessTokens(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *accessTokens) test func(t *testing.T, ctx context.Context, db *accessTokens)
}{ }{
{"Create", accessTokensCreate}, {"Create", accessTokensCreate},
{"DeleteByID", accessTokensDeleteByID}, {"DeleteByID", accessTokensDeleteByID},
@ -118,7 +119,7 @@ func TestAccessTokens(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -126,9 +127,7 @@ func TestAccessTokens(t *testing.T) {
} }
} }
func accessTokensCreate(t *testing.T, db *accessTokens) { func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) {
ctx := context.Background()
// Create first access token with name "Test" // Create first access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)
@ -153,9 +152,7 @@ func accessTokensCreate(t *testing.T, db *accessTokens) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func accessTokensDeleteByID(t *testing.T, db *accessTokens) { func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) {
ctx := context.Background()
// Create an access token with name "Test" // Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)
@ -182,9 +179,7 @@ func accessTokensDeleteByID(t *testing.T, db *accessTokens) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func accessTokensGetBySHA(t *testing.T, db *accessTokens) { func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) {
ctx := context.Background()
// Create an access token with name "Test" // Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)
@ -203,9 +198,7 @@ func accessTokensGetBySHA(t *testing.T, db *accessTokens) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func accessTokensList(t *testing.T, db *accessTokens) { func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) {
ctx := context.Background()
// Create two access tokens for user 1 // Create two access tokens for user 1
_, err := db.Create(ctx, 1, "user1_1") _, err := db.Create(ctx, 1, "user1_1")
require.NoError(t, err) require.NoError(t, err)
@ -228,9 +221,7 @@ func accessTokensList(t *testing.T, db *accessTokens) {
assert.Equal(t, "user1_2", tokens[1].Name) assert.Equal(t, "user1_2", tokens[1].Name)
} }
func accessTokensTouch(t *testing.T, db *accessTokens) { func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokens) {
ctx := context.Background()
// Create an access token with name "Test" // Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test") token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err) require.NoError(t, err)

View File

@ -97,8 +97,9 @@ func TestActions(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip() t.Skip()
} }
t.Parallel()
ctx := context.Background()
t.Parallel()
tables := []any{new(Action), new(User), new(Repository), new(EmailAddress), new(Watch)} tables := []any{new(Action), new(User), new(Repository), new(EmailAddress), new(Watch)}
db := &actions{ db := &actions{
DB: dbtest.NewDB(t, "actions", tables...), DB: dbtest.NewDB(t, "actions", tables...),
@ -106,7 +107,7 @@ func TestActions(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *actions) test func(t *testing.T, ctx context.Context, db *actions)
}{ }{
{"CommitRepo", actionsCommitRepo}, {"CommitRepo", actionsCommitRepo},
{"ListByOrganization", actionsListByOrganization}, {"ListByOrganization", actionsListByOrganization},
@ -125,7 +126,7 @@ func TestActions(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -133,9 +134,7 @@ func TestActions(t *testing.T) {
} }
} }
func actionsCommitRepo(t *testing.T, db *actions) { func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -327,14 +326,12 @@ func actionsCommitRepo(t *testing.T, db *actions) {
}) })
} }
func actionsListByOrganization(t *testing.T, db *actions) { func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL") t.Skip("Skipping testing with not using PostgreSQL")
return return
} }
ctx := context.Background()
conf.SetMockUI(t, conf.SetMockUI(t,
conf.UIOpts{ conf.UIOpts{
User: conf.UIUserOpts{ User: conf.UIUserOpts{
@ -375,14 +372,12 @@ func actionsListByOrganization(t *testing.T, db *actions) {
} }
} }
func actionsListByUser(t *testing.T, db *actions) { func actionsListByUser(t *testing.T, ctx context.Context, db *actions) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL") t.Skip("Skipping testing with not using PostgreSQL")
return return
} }
ctx := context.Background()
conf.SetMockUI(t, conf.SetMockUI(t,
conf.UIOpts{ conf.UIOpts{
User: conf.UIUserOpts{ User: conf.UIUserOpts{
@ -442,9 +437,7 @@ func actionsListByUser(t *testing.T, db *actions) {
} }
} }
func actionsMergePullRequest(t *testing.T, db *actions) { func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -489,9 +482,7 @@ func actionsMergePullRequest(t *testing.T, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsMirrorSyncCreate(t *testing.T, db *actions) { func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -532,9 +523,7 @@ func actionsMirrorSyncCreate(t *testing.T, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsMirrorSyncDelete(t *testing.T, db *actions) { func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -575,9 +564,7 @@ func actionsMirrorSyncDelete(t *testing.T, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsMirrorSyncPush(t *testing.T, db *actions) { func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -642,9 +629,7 @@ func actionsMirrorSyncPush(t *testing.T, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsNewRepo(t *testing.T, db *actions) { func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -719,9 +704,7 @@ func actionsNewRepo(t *testing.T, db *actions) {
}) })
} }
func actionsPushTag(t *testing.T, db *actions) { func actionsPushTag(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
// NOTE: We set a noop mock here to avoid data race with other tests that writes // NOTE: We set a noop mock here to avoid data race with other tests that writes
// to the mock server because this function holds a lock. // to the mock server because this function holds a lock.
conf.SetMockServer(t, conf.ServerOpts{}) conf.SetMockServer(t, conf.ServerOpts{})
@ -817,9 +800,7 @@ func actionsPushTag(t *testing.T, db *actions) {
}) })
} }
func actionsRenameRepo(t *testing.T, db *actions) { func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx, repo, err := NewReposStore(db.DB).Create(ctx,
@ -856,9 +837,7 @@ func actionsRenameRepo(t *testing.T, db *actions) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func actionsTransferRepo(t *testing.T, db *actions) { func actionsTransferRepo(t *testing.T, ctx context.Context, db *actions) {
ctx := context.Background()
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{})

View File

@ -23,6 +23,7 @@ func TestLFS(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(LFSObject)} tables := []any{new(LFSObject)}
db := &lfs{ db := &lfs{
DB: dbtest.NewDB(t, "lfs", tables...), DB: dbtest.NewDB(t, "lfs", tables...),
@ -30,7 +31,7 @@ func TestLFS(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *lfs) test func(t *testing.T, ctx context.Context, db *lfs)
}{ }{
{"CreateObject", lfsCreateObject}, {"CreateObject", lfsCreateObject},
{"GetObjectByOID", lfsGetObjectByOID}, {"GetObjectByOID", lfsGetObjectByOID},
@ -41,7 +42,7 @@ func TestLFS(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -49,9 +50,7 @@ func TestLFS(t *testing.T) {
} }
} }
func lfsCreateObject(t *testing.T, db *lfs) { func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) {
ctx := context.Background()
// Create first LFS object // Create first LFS object
repoID := int64(1) repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -68,9 +67,7 @@ func lfsCreateObject(t *testing.T, db *lfs) {
assert.Error(t, err) assert.Error(t, err)
} }
func lfsGetObjectByOID(t *testing.T, db *lfs) { func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) {
ctx := context.Background()
// Create a LFS object // Create a LFS object
repoID := int64(1) repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -87,9 +84,7 @@ func lfsGetObjectByOID(t *testing.T, db *lfs) {
assert.Equal(t, expErr, err) assert.Equal(t, expErr, err)
} }
func lfsGetObjectsByOIDs(t *testing.T, db *lfs) { func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfs) {
ctx := context.Background()
// Create two LFS objects // Create two LFS objects
repoID := int64(1) repoID := int64(1)
oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")

View File

@ -163,6 +163,7 @@ func TestLoginSources(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(LoginSource), new(User)} tables := []any{new(LoginSource), new(User)}
db := &loginSources{ db := &loginSources{
DB: dbtest.NewDB(t, "loginSources", tables...), DB: dbtest.NewDB(t, "loginSources", tables...),
@ -170,7 +171,7 @@ func TestLoginSources(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *loginSources) test func(t *testing.T, ctx context.Context, db *loginSources)
}{ }{
{"Create", loginSourcesCreate}, {"Create", loginSourcesCreate},
{"Count", loginSourcesCount}, {"Count", loginSourcesCount},
@ -185,7 +186,7 @@ func TestLoginSources(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -193,9 +194,7 @@ func TestLoginSources(t *testing.T) {
} }
} }
func loginSourcesCreate(t *testing.T, db *loginSources) { func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
ctx := context.Background()
// Create first login source with name "GitHub" // Create first login source with name "GitHub"
source, err := db.Create(ctx, source, err := db.Create(ctx,
CreateLoginSourceOptions{ CreateLoginSourceOptions{
@ -222,9 +221,7 @@ func loginSourcesCreate(t *testing.T, db *loginSources) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func loginSourcesCount(t *testing.T, db *loginSources) { func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
ctx := context.Background()
// Create two login sources, one in database and one as source file. // Create two login sources, one in database and one as source file.
_, err := db.Create(ctx, _, err := db.Create(ctx,
CreateLoginSourceOptions{ CreateLoginSourceOptions{
@ -246,9 +243,7 @@ func loginSourcesCount(t *testing.T, db *loginSources) {
assert.Equal(t, int64(3), db.Count(ctx)) assert.Equal(t, int64(3), db.Count(ctx))
} }
func loginSourcesDeleteByID(t *testing.T, db *loginSources) { func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) {
ctx := context.Background()
t.Run("delete but in used", func(t *testing.T) { t.Run("delete but in used", func(t *testing.T) {
source, err := db.Create(ctx, source, err := db.Create(ctx,
CreateLoginSourceOptions{ CreateLoginSourceOptions{
@ -315,9 +310,7 @@ func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func loginSourcesGetByID(t *testing.T, db *loginSources) { func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
ctx := context.Background()
mock := NewMockLoginSourceFilesStore() mock := NewMockLoginSourceFilesStore()
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) { mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
if id != 101 { if id != 101 {
@ -353,9 +346,7 @@ func loginSourcesGetByID(t *testing.T, db *loginSources) {
require.NoError(t, err) require.NoError(t, err)
} }
func loginSourcesList(t *testing.T, db *loginSources) { func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
ctx := context.Background()
mock := NewMockLoginSourceFilesStore() mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
if opts.OnlyActivated { if opts.OnlyActivated {
@ -404,9 +395,7 @@ func loginSourcesList(t *testing.T, db *loginSources) {
assert.Equal(t, 2, len(sources), "number of sources") assert.Equal(t, 2, len(sources), "number of sources")
} }
func loginSourcesResetNonDefault(t *testing.T, db *loginSources) { func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) {
ctx := context.Background()
mock := NewMockLoginSourceFilesStore() mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
mockFile := NewMockLoginSourceFileStore() mockFile := NewMockLoginSourceFileStore()
@ -461,9 +450,7 @@ func loginSourcesResetNonDefault(t *testing.T, db *loginSources) {
assert.False(t, source2.IsDefault) assert.False(t, source2.IsDefault)
} }
func loginSourcesSave(t *testing.T, db *loginSources) { func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) {
ctx := context.Background()
t.Run("save to database", func(t *testing.T) { t.Run("save to database", func(t *testing.T) {
// Create a login source with name "GitHub" // Create a login source with name "GitHub"
source, err := db.Create(ctx, source, err := db.Create(ctx,

View File

@ -66,6 +66,7 @@ func TestNotices(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(Notice)} tables := []any{new(Notice)}
db := &notices{ db := &notices{
DB: dbtest.NewDB(t, "notices", tables...), DB: dbtest.NewDB(t, "notices", tables...),
@ -73,7 +74,7 @@ func TestNotices(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *notices) test func(t *testing.T, ctx context.Context, db *notices)
}{ }{
{"Create", noticesCreate}, {"Create", noticesCreate},
{"DeleteByIDs", noticesDeleteByIDs}, {"DeleteByIDs", noticesDeleteByIDs},
@ -86,7 +87,7 @@ func TestNotices(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -94,9 +95,7 @@ func TestNotices(t *testing.T) {
} }
} }
func noticesCreate(t *testing.T, db *notices) { func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
ctx := context.Background()
err := db.Create(ctx, NoticeTypeRepository, "test") err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err) require.NoError(t, err)
@ -104,9 +103,7 @@ func noticesCreate(t *testing.T, db *notices) {
assert.Equal(t, int64(1), count) assert.Equal(t, int64(1), count)
} }
func noticesDeleteByIDs(t *testing.T, db *notices) { func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
ctx := context.Background()
err := db.Create(ctx, NoticeTypeRepository, "test") err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err) require.NoError(t, err)
@ -126,9 +123,7 @@ func noticesDeleteByIDs(t *testing.T, db *notices) {
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
} }
func noticesDeleteAll(t *testing.T, db *notices) { func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
ctx := context.Background()
err := db.Create(ctx, NoticeTypeRepository, "test") err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err) require.NoError(t, err)
@ -139,9 +134,7 @@ func noticesDeleteAll(t *testing.T, db *notices) {
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
} }
func noticesList(t *testing.T, db *notices) { func noticesList(t *testing.T, ctx context.Context, db *notices) {
ctx := context.Background()
err := db.Create(ctx, NoticeTypeRepository, "test 1") err := db.Create(ctx, NoticeTypeRepository, "test 1")
require.NoError(t, err) require.NoError(t, err)
err = db.Create(ctx, NoticeTypeRepository, "test 2") err = db.Create(ctx, NoticeTypeRepository, "test 2")
@ -161,9 +154,7 @@ func noticesList(t *testing.T, db *notices) {
require.Len(t, got, 2) require.Len(t, got, 2)
} }
func noticesCount(t *testing.T, db *notices) { func noticesCount(t *testing.T, ctx context.Context, db *notices) {
ctx := context.Background()
count := db.Count(ctx) count := db.Count(ctx)
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)

View File

@ -21,6 +21,7 @@ func TestOrgs(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(User), new(EmailAddress), new(OrgUser)} tables := []any{new(User), new(EmailAddress), new(OrgUser)}
db := &orgs{ db := &orgs{
DB: dbtest.NewDB(t, "orgs", tables...), DB: dbtest.NewDB(t, "orgs", tables...),
@ -28,7 +29,7 @@ func TestOrgs(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *orgs) test func(t *testing.T, ctx context.Context, db *orgs)
}{ }{
{"List", orgsList}, {"List", orgsList},
{"SearchByName", orgsSearchByName}, {"SearchByName", orgsSearchByName},
@ -39,7 +40,7 @@ func TestOrgs(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -47,9 +48,7 @@ func TestOrgs(t *testing.T) {
} }
} }
func orgsList(t *testing.T, db *orgs) { func orgsList(t *testing.T, ctx context.Context, db *orgs) {
ctx := context.Background()
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -119,9 +118,7 @@ func orgsList(t *testing.T, db *orgs) {
} }
} }
func orgsSearchByName(t *testing.T, db *orgs) { func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) {
ctx := context.Background()
// TODO: Use Orgs.Create to replace SQL hack when the method is available. // TODO: Use Orgs.Create to replace SQL hack when the method is available.
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"}) org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"})
@ -166,9 +163,7 @@ func orgsSearchByName(t *testing.T, db *orgs) {
}) })
} }
func orgsCountByUser(t *testing.T, db *orgs) { func orgsCountByUser(t *testing.T, ctx context.Context, db *orgs) {
ctx := context.Background()
// TODO: Use Orgs.Join to replace SQL hack when the method is available. // TODO: Use Orgs.Join to replace SQL hack when the method is available.
err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error
require.NoError(t, err) require.NoError(t, err)

View File

@ -20,6 +20,7 @@ func TestPerms(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(Access)} tables := []any{new(Access)}
db := &perms{ db := &perms{
DB: dbtest.NewDB(t, "perms", tables...), DB: dbtest.NewDB(t, "perms", tables...),
@ -27,7 +28,7 @@ func TestPerms(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *perms) test func(t *testing.T, ctx context.Context, db *perms)
}{ }{
{"AccessMode", permsAccessMode}, {"AccessMode", permsAccessMode},
{"Authorize", permsAuthorize}, {"Authorize", permsAuthorize},
@ -38,7 +39,7 @@ func TestPerms(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -46,9 +47,7 @@ func TestPerms(t *testing.T) {
} }
} }
func permsAccessMode(t *testing.T, db *perms) { func permsAccessMode(t *testing.T, ctx context.Context, db *perms) {
ctx := context.Background()
// Set up permissions // Set up permissions
err := db.SetRepoPerms(ctx, 1, err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{ map[int64]AccessMode{
@ -159,9 +158,7 @@ func permsAccessMode(t *testing.T, db *perms) {
} }
} }
func permsAuthorize(t *testing.T, db *perms) { func permsAuthorize(t *testing.T, ctx context.Context, db *perms) {
ctx := context.Background()
// Set up permissions // Set up permissions
err := db.SetRepoPerms(ctx, 1, err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{ map[int64]AccessMode{
@ -247,9 +244,7 @@ func permsAuthorize(t *testing.T, db *perms) {
} }
} }
func permsSetRepoPerms(t *testing.T, db *perms) { func permsSetRepoPerms(t *testing.T, ctx context.Context, db *perms) {
ctx := context.Background()
for _, update := range []struct { for _, update := range []struct {
repoID int64 repoID int64
accessMap map[int64]AccessMode accessMap map[int64]AccessMode

View File

@ -5,6 +5,7 @@
package db package db
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -23,6 +24,7 @@ func TestPublicKeys(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(PublicKey)} tables := []any{new(PublicKey)}
db := &publicKeys{ db := &publicKeys{
DB: dbtest.NewDB(t, "publicKeys", tables...), DB: dbtest.NewDB(t, "publicKeys", tables...),
@ -30,7 +32,7 @@ func TestPublicKeys(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *publicKeys) test func(t *testing.T, ctx context.Context, db *publicKeys)
}{ }{
{"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys}, {"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys},
} { } {
@ -39,7 +41,7 @@ func TestPublicKeys(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -47,7 +49,7 @@ func TestPublicKeys(t *testing.T) {
} }
} }
func publicKeysRewriteAuthorizedKeys(t *testing.T, db *publicKeys) { func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeys) {
// TODO: Use PublicKeys.Add to replace SQL hack when the method is available. // TODO: Use PublicKeys.Add to replace SQL hack when the method is available.
publicKey := &PublicKey{ publicKey := &PublicKey{
OwnerID: 1, OwnerID: 1,

View File

@ -85,6 +85,7 @@ func TestRepos(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(Repository), new(Access), new(Watch), new(User), new(EmailAddress), new(Star)} tables := []any{new(Repository), new(Access), new(Watch), new(User), new(EmailAddress), new(Star)}
db := &repos{ db := &repos{
DB: dbtest.NewDB(t, "repos", tables...), DB: dbtest.NewDB(t, "repos", tables...),
@ -92,7 +93,7 @@ func TestRepos(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *repos) test func(t *testing.T, ctx context.Context, db *repos)
}{ }{
{"Create", reposCreate}, {"Create", reposCreate},
{"GetByCollaboratorID", reposGetByCollaboratorID}, {"GetByCollaboratorID", reposGetByCollaboratorID},
@ -110,7 +111,7 @@ func TestRepos(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -118,9 +119,7 @@ func TestRepos(t *testing.T) {
} }
} }
func reposCreate(t *testing.T, db *repos) { func reposCreate(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
t.Run("name not allowed", func(t *testing.T) { t.Run("name not allowed", func(t *testing.T) {
_, err := db.Create(ctx, _, err := db.Create(ctx,
1, 1,
@ -162,9 +161,7 @@ func reposCreate(t *testing.T, db *repos) {
assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default. assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default.
} }
func reposGetByCollaboratorID(t *testing.T, db *repos) { func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -190,9 +187,7 @@ func reposGetByCollaboratorID(t *testing.T, db *repos) {
}) })
} }
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, db *repos) { func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -220,9 +215,7 @@ func reposGetByCollaboratorIDWithAccessMode(t *testing.T, db *repos) {
assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID]) assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID])
} }
func reposGetByID(t *testing.T, db *repos) { func reposGetByID(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
@ -235,9 +228,7 @@ func reposGetByID(t *testing.T, db *repos) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func reposGetByName(t *testing.T, db *repos) { func reposGetByName(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
repo, err := db.Create(ctx, 1, repo, err := db.Create(ctx, 1,
CreateRepoOptions{ CreateRepoOptions{
Name: "repo1", Name: "repo1",
@ -253,9 +244,7 @@ func reposGetByName(t *testing.T, db *repos) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func reposStar(t *testing.T, db *repos) { func reposStar(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
@ -274,9 +263,7 @@ func reposStar(t *testing.T, db *repos) {
assert.Equal(t, 1, alice.NumStars) assert.Equal(t, 1, alice.NumStars)
} }
func reposTouch(t *testing.T, db *repos) { func reposTouch(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
repo, err := db.Create(ctx, 1, repo, err := db.Create(ctx, 1,
CreateRepoOptions{ CreateRepoOptions{
Name: "repo1", Name: "repo1",
@ -302,9 +289,7 @@ func reposTouch(t *testing.T, db *repos) {
assert.False(t, got.IsBare) assert.False(t, got.IsBare)
} }
func reposListWatches(t *testing.T, db *repos) { func reposListWatches(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
err := db.Watch(ctx, 1, 1) err := db.Watch(ctx, 1, 1)
require.NoError(t, err) require.NoError(t, err)
err = db.Watch(ctx, 2, 1) err = db.Watch(ctx, 2, 1)
@ -325,9 +310,7 @@ func reposListWatches(t *testing.T, db *repos) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func reposWatch(t *testing.T, db *repos) { func reposWatch(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
reposStore := NewReposStore(db.DB) reposStore := NewReposStore(db.DB)
repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err) require.NoError(t, err)
@ -344,9 +327,7 @@ func reposWatch(t *testing.T, db *repos) {
assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default. assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default.
} }
func reposHasForkedBy(t *testing.T, db *repos) { func reposHasForkedBy(t *testing.T, ctx context.Context, db *repos) {
ctx := context.Background()
has := db.HasForkedBy(ctx, 1, 2) has := db.HasForkedBy(ctx, 1, 2)
assert.False(t, has) assert.False(t, has)

View File

@ -67,6 +67,7 @@ func TestTwoFactors(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{new(TwoFactor), new(TwoFactorRecoveryCode)} tables := []any{new(TwoFactor), new(TwoFactorRecoveryCode)}
db := &twoFactors{ db := &twoFactors{
DB: dbtest.NewDB(t, "twoFactors", tables...), DB: dbtest.NewDB(t, "twoFactors", tables...),
@ -74,7 +75,7 @@ func TestTwoFactors(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *twoFactors) test func(t *testing.T, ctx context.Context, db *twoFactors)
}{ }{
{"Create", twoFactorsCreate}, {"Create", twoFactorsCreate},
{"GetByUserID", twoFactorsGetByUserID}, {"GetByUserID", twoFactorsGetByUserID},
@ -85,7 +86,7 @@ func TestTwoFactors(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -93,9 +94,7 @@ func TestTwoFactors(t *testing.T) {
} }
} }
func twoFactorsCreate(t *testing.T, db *twoFactors) { func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) {
ctx := context.Background()
// Create a 2FA token // Create a 2FA token
err := db.Create(ctx, 1, "secure-key", "secure-secret") err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err) require.NoError(t, err)
@ -112,9 +111,7 @@ func twoFactorsCreate(t *testing.T, db *twoFactors) {
assert.Equal(t, int64(10), count) assert.Equal(t, int64(10), count)
} }
func twoFactorsGetByUserID(t *testing.T, db *twoFactors) { func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) {
ctx := context.Background()
// Create a 2FA token for user 1 // Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret") err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err) require.NoError(t, err)
@ -129,9 +126,7 @@ func twoFactorsGetByUserID(t *testing.T, db *twoFactors) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func twoFactorsIsEnabled(t *testing.T, db *twoFactors) { func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactors) {
ctx := context.Background()
// Create a 2FA token for user 1 // Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret") err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err) require.NoError(t, err)

View File

@ -84,6 +84,7 @@ func TestUsers(t *testing.T) {
} }
t.Parallel() t.Parallel()
ctx := context.Background()
tables := []any{ tables := []any{
new(User), new(EmailAddress), new(Repository), new(Follow), new(PullRequest), new(PublicKey), new(OrgUser), new(User), new(EmailAddress), new(Repository), new(Follow), new(PullRequest), new(PublicKey), new(OrgUser),
new(Watch), new(Star), new(Issue), new(AccessToken), new(Collaboration), new(Action), new(IssueUser), new(Watch), new(Star), new(Issue), new(AccessToken), new(Collaboration), new(Action), new(IssueUser),
@ -95,7 +96,7 @@ func TestUsers(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
test func(t *testing.T, db *users) test func(t *testing.T, ctx context.Context, db *users)
}{ }{
{"Authenticate", usersAuthenticate}, {"Authenticate", usersAuthenticate},
{"ChangeUsername", usersChangeUsername}, {"ChangeUsername", usersChangeUsername},
@ -131,7 +132,7 @@ func TestUsers(t *testing.T) {
err := clearTables(t, db.DB, tables...) err := clearTables(t, db.DB, tables...)
require.NoError(t, err) require.NoError(t, err)
}) })
tc.test(t, db) tc.test(t, ctx, db)
}) })
if t.Failed() { if t.Failed() {
break break
@ -139,9 +140,7 @@ func TestUsers(t *testing.T) {
} }
} }
func usersAuthenticate(t *testing.T, db *users) { func usersAuthenticate(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
password := "pa$$word" password := "pa$$word"
alice, err := db.Create(ctx, "alice", "alice@example.com", alice, err := db.Create(ctx, "alice", "alice@example.com",
CreateUserOptions{ CreateUserOptions{
@ -236,9 +235,7 @@ func usersAuthenticate(t *testing.T, db *users) {
}) })
} }
func usersChangeUsername(t *testing.T, db *users) { func usersChangeUsername(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create( alice, err := db.Create(
ctx, ctx,
"alice", "alice",
@ -368,9 +365,7 @@ func usersChangeUsername(t *testing.T, db *users) {
assert.Equal(t, strings.ToUpper(newUsername), alice.Name) assert.Equal(t, strings.ToUpper(newUsername), alice.Name)
} }
func usersCount(t *testing.T, db *users) { func usersCount(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
// Has no user initially // Has no user initially
got := db.Count(ctx) got := db.Count(ctx)
assert.Equal(t, int64(0), got) assert.Equal(t, int64(0), got)
@ -393,9 +388,7 @@ func usersCount(t *testing.T, db *users) {
assert.Equal(t, int64(1), got) assert.Equal(t, int64(1), got)
} }
func usersCreate(t *testing.T, db *users) { func usersCreate(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create( alice, err := db.Create(
ctx, ctx,
"alice", "alice",
@ -443,9 +436,7 @@ func usersCreate(t *testing.T, db *users) {
assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339)) assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
} }
func usersDeleteCustomAvatar(t *testing.T, db *users) { func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -479,8 +470,7 @@ func usersDeleteCustomAvatar(t *testing.T, db *users) {
assert.False(t, alice.UseCustomAvatar) assert.False(t, alice.UseCustomAvatar)
} }
func usersDeleteByID(t *testing.T, db *users) { func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
reposStore := NewReposStore(db.DB) reposStore := NewReposStore(db.DB)
t.Run("user still has repository ownership", func(t *testing.T) { t.Run("user still has repository ownership", func(t *testing.T) {
@ -690,9 +680,7 @@ func usersDeleteByID(t *testing.T, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersDeleteInactivated(t *testing.T, db *users) { func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
// User with repository ownership should be skipped // User with repository ownership should be skipped
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -738,9 +726,7 @@ func usersDeleteInactivated(t *testing.T, db *users) {
require.Len(t, users, 3) require.Len(t, users, 3)
} }
func usersGetByEmail(t *testing.T, db *users) { func usersGetByEmail(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
t.Run("empty email", func(t *testing.T) { t.Run("empty email", func(t *testing.T) {
_, err := db.GetByEmail(ctx, "") _, err := db.GetByEmail(ctx, "")
wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}} wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}}
@ -801,9 +787,7 @@ func usersGetByEmail(t *testing.T, db *users) {
}) })
} }
func usersGetByID(t *testing.T, db *users) { func usersGetByID(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -816,9 +800,7 @@ func usersGetByID(t *testing.T, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersGetByUsername(t *testing.T, db *users) { func usersGetByUsername(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -831,9 +813,7 @@ func usersGetByUsername(t *testing.T, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersGetByKeyID(t *testing.T, db *users) { func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -858,9 +838,7 @@ func usersGetByKeyID(t *testing.T, db *users) {
assert.Equal(t, wantErr, err) assert.Equal(t, wantErr, err)
} }
func usersGetMailableEmailsByUsernames(t *testing.T, db *users) { func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true}) bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true})
@ -874,9 +852,7 @@ func usersGetMailableEmailsByUsernames(t *testing.T, db *users) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func usersIsUsernameUsed(t *testing.T, db *users) { func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -926,9 +902,7 @@ func usersIsUsernameUsed(t *testing.T, db *users) {
} }
} }
func usersList(t *testing.T, db *users) { func usersList(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{})
@ -961,9 +935,7 @@ func usersList(t *testing.T, db *users) {
assert.Equal(t, bob.ID, got[1].ID) assert.Equal(t, bob.ID, got[1].ID)
} }
func usersListFollowers(t *testing.T, db *users) { func usersListFollowers(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -994,9 +966,7 @@ func usersListFollowers(t *testing.T, db *users) {
assert.Equal(t, alice.ID, got[0].ID) assert.Equal(t, alice.ID, got[0].ID)
} }
func usersListFollowings(t *testing.T, db *users) { func usersListFollowings(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1027,9 +997,7 @@ func usersListFollowings(t *testing.T, db *users) {
assert.Equal(t, alice.ID, got[0].ID) assert.Equal(t, alice.ID, got[0].ID)
} }
func usersSearchByName(t *testing.T, db *users) { func usersSearchByName(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"})
require.NoError(t, err) require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"}) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"})
@ -1067,9 +1035,7 @@ func usersSearchByName(t *testing.T, db *users) {
}) })
} }
func usersUpdate(t *testing.T, db *users) { func usersUpdate(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
const oldPassword = "Password" const oldPassword = "Password"
alice, err := db.Create( alice, err := db.Create(
ctx, ctx,
@ -1182,9 +1148,7 @@ func usersUpdate(t *testing.T, db *users) {
assertValues() assertValues()
} }
func usersUseCustomAvatar(t *testing.T, db *users) { func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1222,9 +1186,7 @@ func TestIsUsernameAllowed(t *testing.T) {
} }
} }
func usersAddEmail(t *testing.T, db *users) { func usersAddEmail(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
t.Run("multiple users can add the same unverified email", func(t *testing.T) { t.Run("multiple users can add the same unverified email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1241,9 +1203,7 @@ func usersAddEmail(t *testing.T, db *users) {
}) })
} }
func usersGetEmail(t *testing.T, db *users) { func usersGetEmail(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
const testUserID = 1 const testUserID = 1
const testEmail = "alice@example.com" const testEmail = "alice@example.com"
_, err := db.GetEmail(ctx, testUserID, testEmail, false) _, err := db.GetEmail(ctx, testUserID, testEmail, false)
@ -1275,9 +1235,7 @@ func usersGetEmail(t *testing.T, db *users) {
assert.Equal(t, testEmail, got.Email) assert.Equal(t, testEmail, got.Email)
} }
func usersListEmails(t *testing.T, db *users) { func usersListEmails(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
t.Run("list emails with primary email", func(t *testing.T) { t.Run("list emails with primary email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1313,9 +1271,7 @@ func usersListEmails(t *testing.T, db *users) {
}) })
} }
func usersMarkEmailActivated(t *testing.T, db *users) { func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1333,8 +1289,7 @@ func usersMarkEmailActivated(t *testing.T, db *users) {
assert.NotEqual(t, alice.Rands, gotAlice.Rands) assert.NotEqual(t, alice.Rands, gotAlice.Rands)
} }
func usersMarkEmailPrimary(t *testing.T, db *users) { func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false) err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false)
@ -1360,8 +1315,7 @@ func usersMarkEmailPrimary(t *testing.T, db *users) {
assert.False(t, gotEmail.IsActivated) assert.False(t, gotEmail.IsActivated)
} }
func usersDeleteEmail(t *testing.T, db *users) { func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1377,9 +1331,7 @@ func usersDeleteEmail(t *testing.T, db *users) {
require.Equal(t, want, got) require.Equal(t, want, got)
} }
func usersFollow(t *testing.T, db *users) { func usersFollow(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1402,9 +1354,7 @@ func usersFollow(t *testing.T, db *users) {
assert.Equal(t, 1, bob.NumFollowers) assert.Equal(t, 1, bob.NumFollowers)
} }
func usersIsFollowing(t *testing.T, db *users) { func usersIsFollowing(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1425,9 +1375,7 @@ func usersIsFollowing(t *testing.T, db *users) {
assert.False(t, got) assert.False(t, got)
} }
func usersUnfollow(t *testing.T, db *users) { func usersUnfollow(t *testing.T, ctx context.Context, db *users) {
ctx := context.Background()
usersStore := NewUsersStore(db.DB) usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err) require.NoError(t, err)