db: pass context to tests by default (#7622)

[skip ci]
pull/7623/head
Joe Chen 2023-12-17 16:32:28 -05:00 committed by GitHub
parent 0c7b45ad1f
commit 25fdeaac49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 114 additions and 255 deletions

View File

@ -98,6 +98,7 @@ func TestAccessTokens(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(AccessToken)}
db := &accessTokens{
DB: dbtest.NewDB(t, "accessTokens", tables...),
@ -105,7 +106,7 @@ func TestAccessTokens(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *accessTokens)
test func(t *testing.T, ctx context.Context, db *accessTokens)
}{
{"Create", accessTokensCreate},
{"DeleteByID", accessTokensDeleteByID},
@ -118,7 +119,7 @@ func TestAccessTokens(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -126,9 +127,7 @@ func TestAccessTokens(t *testing.T) {
}
}
func accessTokensCreate(t *testing.T, db *accessTokens) {
ctx := context.Background()
func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) {
// Create first access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)
@ -153,9 +152,7 @@ func accessTokensCreate(t *testing.T, db *accessTokens) {
assert.Equal(t, wantErr, err)
}
func accessTokensDeleteByID(t *testing.T, db *accessTokens) {
ctx := context.Background()
func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) {
// Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)
@ -182,9 +179,7 @@ func accessTokensDeleteByID(t *testing.T, db *accessTokens) {
assert.Equal(t, wantErr, err)
}
func accessTokensGetBySHA(t *testing.T, db *accessTokens) {
ctx := context.Background()
func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) {
// Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)
@ -203,9 +198,7 @@ func accessTokensGetBySHA(t *testing.T, db *accessTokens) {
assert.Equal(t, wantErr, err)
}
func accessTokensList(t *testing.T, db *accessTokens) {
ctx := context.Background()
func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) {
// Create two access tokens for user 1
_, err := db.Create(ctx, 1, "user1_1")
require.NoError(t, err)
@ -228,9 +221,7 @@ func accessTokensList(t *testing.T, db *accessTokens) {
assert.Equal(t, "user1_2", tokens[1].Name)
}
func accessTokensTouch(t *testing.T, db *accessTokens) {
ctx := context.Background()
func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokens) {
// Create an access token with name "Test"
token, err := db.Create(ctx, 1, "Test")
require.NoError(t, err)

View File

@ -97,8 +97,9 @@ func TestActions(t *testing.T) {
if testing.Short() {
t.Skip()
}
t.Parallel()
ctx := context.Background()
t.Parallel()
tables := []any{new(Action), new(User), new(Repository), new(EmailAddress), new(Watch)}
db := &actions{
DB: dbtest.NewDB(t, "actions", tables...),
@ -106,7 +107,7 @@ func TestActions(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *actions)
test func(t *testing.T, ctx context.Context, db *actions)
}{
{"CommitRepo", actionsCommitRepo},
{"ListByOrganization", actionsListByOrganization},
@ -125,7 +126,7 @@ func TestActions(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -133,9 +134,7 @@ func TestActions(t *testing.T) {
}
}
func actionsCommitRepo(t *testing.T, db *actions) {
ctx := context.Background()
func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -327,14 +326,12 @@ func actionsCommitRepo(t *testing.T, db *actions) {
})
}
func actionsListByOrganization(t *testing.T, db *actions) {
func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL")
return
}
ctx := context.Background()
conf.SetMockUI(t,
conf.UIOpts{
User: conf.UIUserOpts{
@ -375,14 +372,12 @@ func actionsListByOrganization(t *testing.T, db *actions) {
}
}
func actionsListByUser(t *testing.T, db *actions) {
func actionsListByUser(t *testing.T, ctx context.Context, db *actions) {
if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" {
t.Skip("Skipping testing with not using PostgreSQL")
return
}
ctx := context.Background()
conf.SetMockUI(t,
conf.UIOpts{
User: conf.UIUserOpts{
@ -442,9 +437,7 @@ func actionsListByUser(t *testing.T, db *actions) {
}
}
func actionsMergePullRequest(t *testing.T, db *actions) {
ctx := context.Background()
func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -489,9 +482,7 @@ func actionsMergePullRequest(t *testing.T, db *actions) {
assert.Equal(t, want, got)
}
func actionsMirrorSyncCreate(t *testing.T, db *actions) {
ctx := context.Background()
func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -532,9 +523,7 @@ func actionsMirrorSyncCreate(t *testing.T, db *actions) {
assert.Equal(t, want, got)
}
func actionsMirrorSyncDelete(t *testing.T, db *actions) {
ctx := context.Background()
func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -575,9 +564,7 @@ func actionsMirrorSyncDelete(t *testing.T, db *actions) {
assert.Equal(t, want, got)
}
func actionsMirrorSyncPush(t *testing.T, db *actions) {
ctx := context.Background()
func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -642,9 +629,7 @@ func actionsMirrorSyncPush(t *testing.T, db *actions) {
assert.Equal(t, want, got)
}
func actionsNewRepo(t *testing.T, db *actions) {
ctx := context.Background()
func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -719,9 +704,7 @@ func actionsNewRepo(t *testing.T, db *actions) {
})
}
func actionsPushTag(t *testing.T, db *actions) {
ctx := context.Background()
func actionsPushTag(t *testing.T, ctx context.Context, db *actions) {
// NOTE: We set a noop mock here to avoid data race with other tests that writes
// to the mock server because this function holds a lock.
conf.SetMockServer(t, conf.ServerOpts{})
@ -817,9 +800,7 @@ func actionsPushTag(t *testing.T, db *actions) {
})
}
func actionsRenameRepo(t *testing.T, db *actions) {
ctx := context.Background()
func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
repo, err := NewReposStore(db.DB).Create(ctx,
@ -856,9 +837,7 @@ func actionsRenameRepo(t *testing.T, db *actions) {
assert.Equal(t, want, got)
}
func actionsTransferRepo(t *testing.T, db *actions) {
ctx := context.Background()
func actionsTransferRepo(t *testing.T, ctx context.Context, db *actions) {
alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{})

View File

@ -23,6 +23,7 @@ func TestLFS(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(LFSObject)}
db := &lfs{
DB: dbtest.NewDB(t, "lfs", tables...),
@ -30,7 +31,7 @@ func TestLFS(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *lfs)
test func(t *testing.T, ctx context.Context, db *lfs)
}{
{"CreateObject", lfsCreateObject},
{"GetObjectByOID", lfsGetObjectByOID},
@ -41,7 +42,7 @@ func TestLFS(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -49,9 +50,7 @@ func TestLFS(t *testing.T) {
}
}
func lfsCreateObject(t *testing.T, db *lfs) {
ctx := context.Background()
func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) {
// Create first LFS object
repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -68,9 +67,7 @@ func lfsCreateObject(t *testing.T, db *lfs) {
assert.Error(t, err)
}
func lfsGetObjectByOID(t *testing.T, db *lfs) {
ctx := context.Background()
func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) {
// Create a LFS object
repoID := int64(1)
oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
@ -87,9 +84,7 @@ func lfsGetObjectByOID(t *testing.T, db *lfs) {
assert.Equal(t, expErr, err)
}
func lfsGetObjectsByOIDs(t *testing.T, db *lfs) {
ctx := context.Background()
func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfs) {
// Create two LFS objects
repoID := int64(1)
oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")

View File

@ -163,6 +163,7 @@ func TestLoginSources(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(LoginSource), new(User)}
db := &loginSources{
DB: dbtest.NewDB(t, "loginSources", tables...),
@ -170,7 +171,7 @@ func TestLoginSources(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *loginSources)
test func(t *testing.T, ctx context.Context, db *loginSources)
}{
{"Create", loginSourcesCreate},
{"Count", loginSourcesCount},
@ -185,7 +186,7 @@ func TestLoginSources(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -193,9 +194,7 @@ func TestLoginSources(t *testing.T) {
}
}
func loginSourcesCreate(t *testing.T, db *loginSources) {
ctx := context.Background()
func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
// Create first login source with name "GitHub"
source, err := db.Create(ctx,
CreateLoginSourceOptions{
@ -222,9 +221,7 @@ func loginSourcesCreate(t *testing.T, db *loginSources) {
assert.Equal(t, wantErr, err)
}
func loginSourcesCount(t *testing.T, db *loginSources) {
ctx := context.Background()
func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
// Create two login sources, one in database and one as source file.
_, err := db.Create(ctx,
CreateLoginSourceOptions{
@ -246,9 +243,7 @@ func loginSourcesCount(t *testing.T, db *loginSources) {
assert.Equal(t, int64(3), db.Count(ctx))
}
func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
ctx := context.Background()
func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) {
t.Run("delete but in used", func(t *testing.T) {
source, err := db.Create(ctx,
CreateLoginSourceOptions{
@ -315,9 +310,7 @@ func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
assert.Equal(t, wantErr, err)
}
func loginSourcesGetByID(t *testing.T, db *loginSources) {
ctx := context.Background()
func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
mock := NewMockLoginSourceFilesStore()
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
if id != 101 {
@ -353,9 +346,7 @@ func loginSourcesGetByID(t *testing.T, db *loginSources) {
require.NoError(t, err)
}
func loginSourcesList(t *testing.T, db *loginSources) {
ctx := context.Background()
func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
if opts.OnlyActivated {
@ -404,9 +395,7 @@ func loginSourcesList(t *testing.T, db *loginSources) {
assert.Equal(t, 2, len(sources), "number of sources")
}
func loginSourcesResetNonDefault(t *testing.T, db *loginSources) {
ctx := context.Background()
func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) {
mock := NewMockLoginSourceFilesStore()
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
mockFile := NewMockLoginSourceFileStore()
@ -461,9 +450,7 @@ func loginSourcesResetNonDefault(t *testing.T, db *loginSources) {
assert.False(t, source2.IsDefault)
}
func loginSourcesSave(t *testing.T, db *loginSources) {
ctx := context.Background()
func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) {
t.Run("save to database", func(t *testing.T) {
// Create a login source with name "GitHub"
source, err := db.Create(ctx,

View File

@ -66,6 +66,7 @@ func TestNotices(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(Notice)}
db := &notices{
DB: dbtest.NewDB(t, "notices", tables...),
@ -73,7 +74,7 @@ func TestNotices(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *notices)
test func(t *testing.T, ctx context.Context, db *notices)
}{
{"Create", noticesCreate},
{"DeleteByIDs", noticesDeleteByIDs},
@ -86,7 +87,7 @@ func TestNotices(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -94,9 +95,7 @@ func TestNotices(t *testing.T) {
}
}
func noticesCreate(t *testing.T, db *notices) {
ctx := context.Background()
func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err)
@ -104,9 +103,7 @@ func noticesCreate(t *testing.T, db *notices) {
assert.Equal(t, int64(1), count)
}
func noticesDeleteByIDs(t *testing.T, db *notices) {
ctx := context.Background()
func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err)
@ -126,9 +123,7 @@ func noticesDeleteByIDs(t *testing.T, db *notices) {
assert.Equal(t, int64(0), count)
}
func noticesDeleteAll(t *testing.T, db *notices) {
ctx := context.Background()
func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
err := db.Create(ctx, NoticeTypeRepository, "test")
require.NoError(t, err)
@ -139,9 +134,7 @@ func noticesDeleteAll(t *testing.T, db *notices) {
assert.Equal(t, int64(0), count)
}
func noticesList(t *testing.T, db *notices) {
ctx := context.Background()
func noticesList(t *testing.T, ctx context.Context, db *notices) {
err := db.Create(ctx, NoticeTypeRepository, "test 1")
require.NoError(t, err)
err = db.Create(ctx, NoticeTypeRepository, "test 2")
@ -161,9 +154,7 @@ func noticesList(t *testing.T, db *notices) {
require.Len(t, got, 2)
}
func noticesCount(t *testing.T, db *notices) {
ctx := context.Background()
func noticesCount(t *testing.T, ctx context.Context, db *notices) {
count := db.Count(ctx)
assert.Equal(t, int64(0), count)

View File

@ -21,6 +21,7 @@ func TestOrgs(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(User), new(EmailAddress), new(OrgUser)}
db := &orgs{
DB: dbtest.NewDB(t, "orgs", tables...),
@ -28,7 +29,7 @@ func TestOrgs(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *orgs)
test func(t *testing.T, ctx context.Context, db *orgs)
}{
{"List", orgsList},
{"SearchByName", orgsSearchByName},
@ -39,7 +40,7 @@ func TestOrgs(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -47,9 +48,7 @@ func TestOrgs(t *testing.T) {
}
}
func orgsList(t *testing.T, db *orgs) {
ctx := context.Background()
func orgsList(t *testing.T, ctx context.Context, db *orgs) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -119,9 +118,7 @@ func orgsList(t *testing.T, db *orgs) {
}
}
func orgsSearchByName(t *testing.T, db *orgs) {
ctx := context.Background()
func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) {
// TODO: Use Orgs.Create to replace SQL hack when the method is available.
usersStore := NewUsersStore(db.DB)
org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"})
@ -166,9 +163,7 @@ func orgsSearchByName(t *testing.T, db *orgs) {
})
}
func orgsCountByUser(t *testing.T, db *orgs) {
ctx := context.Background()
func orgsCountByUser(t *testing.T, ctx context.Context, db *orgs) {
// TODO: Use Orgs.Join to replace SQL hack when the method is available.
err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error
require.NoError(t, err)

View File

@ -20,6 +20,7 @@ func TestPerms(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(Access)}
db := &perms{
DB: dbtest.NewDB(t, "perms", tables...),
@ -27,7 +28,7 @@ func TestPerms(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *perms)
test func(t *testing.T, ctx context.Context, db *perms)
}{
{"AccessMode", permsAccessMode},
{"Authorize", permsAuthorize},
@ -38,7 +39,7 @@ func TestPerms(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -46,9 +47,7 @@ func TestPerms(t *testing.T) {
}
}
func permsAccessMode(t *testing.T, db *perms) {
ctx := context.Background()
func permsAccessMode(t *testing.T, ctx context.Context, db *perms) {
// Set up permissions
err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{
@ -159,9 +158,7 @@ func permsAccessMode(t *testing.T, db *perms) {
}
}
func permsAuthorize(t *testing.T, db *perms) {
ctx := context.Background()
func permsAuthorize(t *testing.T, ctx context.Context, db *perms) {
// Set up permissions
err := db.SetRepoPerms(ctx, 1,
map[int64]AccessMode{
@ -247,9 +244,7 @@ func permsAuthorize(t *testing.T, db *perms) {
}
}
func permsSetRepoPerms(t *testing.T, db *perms) {
ctx := context.Background()
func permsSetRepoPerms(t *testing.T, ctx context.Context, db *perms) {
for _, update := range []struct {
repoID int64
accessMap map[int64]AccessMode

View File

@ -5,6 +5,7 @@
package db
import (
"context"
"fmt"
"os"
"path/filepath"
@ -23,6 +24,7 @@ func TestPublicKeys(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(PublicKey)}
db := &publicKeys{
DB: dbtest.NewDB(t, "publicKeys", tables...),
@ -30,7 +32,7 @@ func TestPublicKeys(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *publicKeys)
test func(t *testing.T, ctx context.Context, db *publicKeys)
}{
{"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys},
} {
@ -39,7 +41,7 @@ func TestPublicKeys(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -47,7 +49,7 @@ func TestPublicKeys(t *testing.T) {
}
}
func publicKeysRewriteAuthorizedKeys(t *testing.T, db *publicKeys) {
func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeys) {
// TODO: Use PublicKeys.Add to replace SQL hack when the method is available.
publicKey := &PublicKey{
OwnerID: 1,

View File

@ -85,6 +85,7 @@ func TestRepos(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(Repository), new(Access), new(Watch), new(User), new(EmailAddress), new(Star)}
db := &repos{
DB: dbtest.NewDB(t, "repos", tables...),
@ -92,7 +93,7 @@ func TestRepos(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *repos)
test func(t *testing.T, ctx context.Context, db *repos)
}{
{"Create", reposCreate},
{"GetByCollaboratorID", reposGetByCollaboratorID},
@ -110,7 +111,7 @@ func TestRepos(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -118,9 +119,7 @@ func TestRepos(t *testing.T) {
}
}
func reposCreate(t *testing.T, db *repos) {
ctx := context.Background()
func reposCreate(t *testing.T, ctx context.Context, db *repos) {
t.Run("name not allowed", func(t *testing.T) {
_, err := db.Create(ctx,
1,
@ -162,9 +161,7 @@ func reposCreate(t *testing.T, db *repos) {
assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default.
}
func reposGetByCollaboratorID(t *testing.T, db *repos) {
ctx := context.Background()
func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -190,9 +187,7 @@ func reposGetByCollaboratorID(t *testing.T, db *repos) {
})
}
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, db *repos) {
ctx := context.Background()
func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *repos) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
@ -220,9 +215,7 @@ func reposGetByCollaboratorIDWithAccessMode(t *testing.T, db *repos) {
assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID])
}
func reposGetByID(t *testing.T, db *repos) {
ctx := context.Background()
func reposGetByID(t *testing.T, ctx context.Context, db *repos) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
@ -235,9 +228,7 @@ func reposGetByID(t *testing.T, db *repos) {
assert.Equal(t, wantErr, err)
}
func reposGetByName(t *testing.T, db *repos) {
ctx := context.Background()
func reposGetByName(t *testing.T, ctx context.Context, db *repos) {
repo, err := db.Create(ctx, 1,
CreateRepoOptions{
Name: "repo1",
@ -253,9 +244,7 @@ func reposGetByName(t *testing.T, db *repos) {
assert.Equal(t, wantErr, err)
}
func reposStar(t *testing.T, db *repos) {
ctx := context.Background()
func reposStar(t *testing.T, ctx context.Context, db *repos) {
repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
usersStore := NewUsersStore(db.DB)
@ -274,9 +263,7 @@ func reposStar(t *testing.T, db *repos) {
assert.Equal(t, 1, alice.NumStars)
}
func reposTouch(t *testing.T, db *repos) {
ctx := context.Background()
func reposTouch(t *testing.T, ctx context.Context, db *repos) {
repo, err := db.Create(ctx, 1,
CreateRepoOptions{
Name: "repo1",
@ -302,9 +289,7 @@ func reposTouch(t *testing.T, db *repos) {
assert.False(t, got.IsBare)
}
func reposListWatches(t *testing.T, db *repos) {
ctx := context.Background()
func reposListWatches(t *testing.T, ctx context.Context, db *repos) {
err := db.Watch(ctx, 1, 1)
require.NoError(t, err)
err = db.Watch(ctx, 2, 1)
@ -325,9 +310,7 @@ func reposListWatches(t *testing.T, db *repos) {
assert.Equal(t, want, got)
}
func reposWatch(t *testing.T, db *repos) {
ctx := context.Background()
func reposWatch(t *testing.T, ctx context.Context, db *repos) {
reposStore := NewReposStore(db.DB)
repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
require.NoError(t, err)
@ -344,9 +327,7 @@ func reposWatch(t *testing.T, db *repos) {
assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default.
}
func reposHasForkedBy(t *testing.T, db *repos) {
ctx := context.Background()
func reposHasForkedBy(t *testing.T, ctx context.Context, db *repos) {
has := db.HasForkedBy(ctx, 1, 2)
assert.False(t, has)

View File

@ -67,6 +67,7 @@ func TestTwoFactors(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{new(TwoFactor), new(TwoFactorRecoveryCode)}
db := &twoFactors{
DB: dbtest.NewDB(t, "twoFactors", tables...),
@ -74,7 +75,7 @@ func TestTwoFactors(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *twoFactors)
test func(t *testing.T, ctx context.Context, db *twoFactors)
}{
{"Create", twoFactorsCreate},
{"GetByUserID", twoFactorsGetByUserID},
@ -85,7 +86,7 @@ func TestTwoFactors(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -93,9 +94,7 @@ func TestTwoFactors(t *testing.T) {
}
}
func twoFactorsCreate(t *testing.T, db *twoFactors) {
ctx := context.Background()
func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) {
// Create a 2FA token
err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err)
@ -112,9 +111,7 @@ func twoFactorsCreate(t *testing.T, db *twoFactors) {
assert.Equal(t, int64(10), count)
}
func twoFactorsGetByUserID(t *testing.T, db *twoFactors) {
ctx := context.Background()
func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) {
// Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err)
@ -129,9 +126,7 @@ func twoFactorsGetByUserID(t *testing.T, db *twoFactors) {
assert.Equal(t, wantErr, err)
}
func twoFactorsIsEnabled(t *testing.T, db *twoFactors) {
ctx := context.Background()
func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactors) {
// Create a 2FA token for user 1
err := db.Create(ctx, 1, "secure-key", "secure-secret")
require.NoError(t, err)

View File

@ -84,6 +84,7 @@ func TestUsers(t *testing.T) {
}
t.Parallel()
ctx := context.Background()
tables := []any{
new(User), new(EmailAddress), new(Repository), new(Follow), new(PullRequest), new(PublicKey), new(OrgUser),
new(Watch), new(Star), new(Issue), new(AccessToken), new(Collaboration), new(Action), new(IssueUser),
@ -95,7 +96,7 @@ func TestUsers(t *testing.T) {
for _, tc := range []struct {
name string
test func(t *testing.T, db *users)
test func(t *testing.T, ctx context.Context, db *users)
}{
{"Authenticate", usersAuthenticate},
{"ChangeUsername", usersChangeUsername},
@ -131,7 +132,7 @@ func TestUsers(t *testing.T) {
err := clearTables(t, db.DB, tables...)
require.NoError(t, err)
})
tc.test(t, db)
tc.test(t, ctx, db)
})
if t.Failed() {
break
@ -139,9 +140,7 @@ func TestUsers(t *testing.T) {
}
}
func usersAuthenticate(t *testing.T, db *users) {
ctx := context.Background()
func usersAuthenticate(t *testing.T, ctx context.Context, db *users) {
password := "pa$$word"
alice, err := db.Create(ctx, "alice", "alice@example.com",
CreateUserOptions{
@ -236,9 +235,7 @@ func usersAuthenticate(t *testing.T, db *users) {
})
}
func usersChangeUsername(t *testing.T, db *users) {
ctx := context.Background()
func usersChangeUsername(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(
ctx,
"alice",
@ -368,9 +365,7 @@ func usersChangeUsername(t *testing.T, db *users) {
assert.Equal(t, strings.ToUpper(newUsername), alice.Name)
}
func usersCount(t *testing.T, db *users) {
ctx := context.Background()
func usersCount(t *testing.T, ctx context.Context, db *users) {
// Has no user initially
got := db.Count(ctx)
assert.Equal(t, int64(0), got)
@ -393,9 +388,7 @@ func usersCount(t *testing.T, db *users) {
assert.Equal(t, int64(1), got)
}
func usersCreate(t *testing.T, db *users) {
ctx := context.Background()
func usersCreate(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(
ctx,
"alice",
@ -443,9 +436,7 @@ func usersCreate(t *testing.T, db *users) {
assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
}
func usersDeleteCustomAvatar(t *testing.T, db *users) {
ctx := context.Background()
func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -479,8 +470,7 @@ func usersDeleteCustomAvatar(t *testing.T, db *users) {
assert.False(t, alice.UseCustomAvatar)
}
func usersDeleteByID(t *testing.T, db *users) {
ctx := context.Background()
func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
reposStore := NewReposStore(db.DB)
t.Run("user still has repository ownership", func(t *testing.T) {
@ -690,9 +680,7 @@ func usersDeleteByID(t *testing.T, db *users) {
assert.Equal(t, wantErr, err)
}
func usersDeleteInactivated(t *testing.T, db *users) {
ctx := context.Background()
func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) {
// User with repository ownership should be skipped
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -738,9 +726,7 @@ func usersDeleteInactivated(t *testing.T, db *users) {
require.Len(t, users, 3)
}
func usersGetByEmail(t *testing.T, db *users) {
ctx := context.Background()
func usersGetByEmail(t *testing.T, ctx context.Context, db *users) {
t.Run("empty email", func(t *testing.T) {
_, err := db.GetByEmail(ctx, "")
wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}}
@ -801,9 +787,7 @@ func usersGetByEmail(t *testing.T, db *users) {
})
}
func usersGetByID(t *testing.T, db *users) {
ctx := context.Background()
func usersGetByID(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -816,9 +800,7 @@ func usersGetByID(t *testing.T, db *users) {
assert.Equal(t, wantErr, err)
}
func usersGetByUsername(t *testing.T, db *users) {
ctx := context.Background()
func usersGetByUsername(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -831,9 +813,7 @@ func usersGetByUsername(t *testing.T, db *users) {
assert.Equal(t, wantErr, err)
}
func usersGetByKeyID(t *testing.T, db *users) {
ctx := context.Background()
func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
@ -858,9 +838,7 @@ func usersGetByKeyID(t *testing.T, db *users) {
assert.Equal(t, wantErr, err)
}
func usersGetMailableEmailsByUsernames(t *testing.T, db *users) {
ctx := context.Background()
func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{})
require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true})
@ -874,9 +852,7 @@ func usersGetMailableEmailsByUsernames(t *testing.T, db *users) {
assert.Equal(t, want, got)
}
func usersIsUsernameUsed(t *testing.T, db *users) {
ctx := context.Background()
func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -926,9 +902,7 @@ func usersIsUsernameUsed(t *testing.T, db *users) {
}
}
func usersList(t *testing.T, db *users) {
ctx := context.Background()
func usersList(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{})
@ -961,9 +935,7 @@ func usersList(t *testing.T, db *users) {
assert.Equal(t, bob.ID, got[1].ID)
}
func usersListFollowers(t *testing.T, db *users) {
ctx := context.Background()
func usersListFollowers(t *testing.T, ctx context.Context, db *users) {
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -994,9 +966,7 @@ func usersListFollowers(t *testing.T, db *users) {
assert.Equal(t, alice.ID, got[0].ID)
}
func usersListFollowings(t *testing.T, db *users) {
ctx := context.Background()
func usersListFollowings(t *testing.T, ctx context.Context, db *users) {
john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1027,9 +997,7 @@ func usersListFollowings(t *testing.T, db *users) {
assert.Equal(t, alice.ID, got[0].ID)
}
func usersSearchByName(t *testing.T, db *users) {
ctx := context.Background()
func usersSearchByName(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"})
require.NoError(t, err)
bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"})
@ -1067,9 +1035,7 @@ func usersSearchByName(t *testing.T, db *users) {
})
}
func usersUpdate(t *testing.T, db *users) {
ctx := context.Background()
func usersUpdate(t *testing.T, ctx context.Context, db *users) {
const oldPassword = "Password"
alice, err := db.Create(
ctx,
@ -1182,9 +1148,7 @@ func usersUpdate(t *testing.T, db *users) {
assertValues()
}
func usersUseCustomAvatar(t *testing.T, db *users) {
ctx := context.Background()
func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1222,9 +1186,7 @@ func TestIsUsernameAllowed(t *testing.T) {
}
}
func usersAddEmail(t *testing.T, db *users) {
ctx := context.Background()
func usersAddEmail(t *testing.T, ctx context.Context, db *users) {
t.Run("multiple users can add the same unverified email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1241,9 +1203,7 @@ func usersAddEmail(t *testing.T, db *users) {
})
}
func usersGetEmail(t *testing.T, db *users) {
ctx := context.Background()
func usersGetEmail(t *testing.T, ctx context.Context, db *users) {
const testUserID = 1
const testEmail = "alice@example.com"
_, err := db.GetEmail(ctx, testUserID, testEmail, false)
@ -1275,9 +1235,7 @@ func usersGetEmail(t *testing.T, db *users) {
assert.Equal(t, testEmail, got.Email)
}
func usersListEmails(t *testing.T, db *users) {
ctx := context.Background()
func usersListEmails(t *testing.T, ctx context.Context, db *users) {
t.Run("list emails with primary email", func(t *testing.T) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1313,9 +1271,7 @@ func usersListEmails(t *testing.T, db *users) {
})
}
func usersMarkEmailActivated(t *testing.T, db *users) {
ctx := context.Background()
func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1333,8 +1289,7 @@ func usersMarkEmailActivated(t *testing.T, db *users) {
assert.NotEqual(t, alice.Rands, gotAlice.Rands)
}
func usersMarkEmailPrimary(t *testing.T, db *users) {
ctx := context.Background()
func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false)
@ -1360,8 +1315,7 @@ func usersMarkEmailPrimary(t *testing.T, db *users) {
assert.False(t, gotEmail.IsActivated)
}
func usersDeleteEmail(t *testing.T, db *users) {
ctx := context.Background()
func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1377,9 +1331,7 @@ func usersDeleteEmail(t *testing.T, db *users) {
require.Equal(t, want, got)
}
func usersFollow(t *testing.T, db *users) {
ctx := context.Background()
func usersFollow(t *testing.T, ctx context.Context, db *users) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1402,9 +1354,7 @@ func usersFollow(t *testing.T, db *users) {
assert.Equal(t, 1, bob.NumFollowers)
}
func usersIsFollowing(t *testing.T, db *users) {
ctx := context.Background()
func usersIsFollowing(t *testing.T, ctx context.Context, db *users) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
@ -1425,9 +1375,7 @@ func usersIsFollowing(t *testing.T, db *users) {
assert.False(t, got)
}
func usersUnfollow(t *testing.T, db *users) {
ctx := context.Background()
func usersUnfollow(t *testing.T, ctx context.Context, db *users) {
usersStore := NewUsersStore(db.DB)
alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)