diff --git a/internal/context/go_get.go b/internal/context/go_get.go index 5a59e1b58..06417ebfb 100644 --- a/internal/context/go_get.go +++ b/internal/context/go_get.go @@ -28,7 +28,7 @@ func ServeGoGet() macaron.Handler { owner, err := db.Users.GetByUsername(c.Req.Context(), ownerName) if err == nil { - repo, err := db.Repos.GetByName(owner.ID, repoName) + repo, err := db.Repos.GetByName(c.Req.Context(), owner.ID, repoName) if err == nil && repo.DefaultBranch != "" { branchName = repo.DefaultBranch } diff --git a/internal/db/db.go b/internal/db/db.go index b58d88c86..e67ffde9e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -73,7 +73,7 @@ func newDSN(opts conf.DatabaseOpts) (dsn string, err error) { case "postgres": host, port := parsePostgreSQLHostPort(opts.Host) - dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s'", + dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s' application_name='gogs'", opts.User, opts.Password, host, port, opts.Name, opts.SSLMode, opts.Schema) case "mssql": diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 1f4f0109c..65f0c067f 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -63,9 +63,9 @@ func Test_parseDSN(t *testing.T) { }) tests := []struct { - name string - opts conf.DatabaseOpts - expDSN string + name string + opts conf.DatabaseOpts + wantDSN string }{ { name: "mysql: unix", @@ -76,7 +76,7 @@ func Test_parseDSN(t *testing.T) { User: "gogs", Password: "pa$$word", }, - expDSN: "gogs:pa$$word@unix(/tmp/mysql.sock)/gogs?charset=utf8mb4&parseTime=true", + wantDSN: "gogs:pa$$word@unix(/tmp/mysql.sock)/gogs?charset=utf8mb4&parseTime=true", }, { name: "mysql: tcp", @@ -87,7 +87,7 @@ func Test_parseDSN(t *testing.T) { User: "gogs", Password: "pa$$word", }, - expDSN: "gogs:pa$$word@tcp(localhost:3306)/gogs?charset=utf8mb4&parseTime=true", + wantDSN: "gogs:pa$$word@tcp(localhost:3306)/gogs?charset=utf8mb4&parseTime=true", }, { @@ -101,7 +101,7 @@ func Test_parseDSN(t *testing.T) { Password: "pa$$word", SSLMode: "disable", }, - expDSN: "user='gogs@local' password='pa$$word' host='/tmp/pg.sock' port='5432' dbname='gogs' sslmode='disable' search_path='test'", + wantDSN: "user='gogs@local' password='pa$$word' host='/tmp/pg.sock' port='5432' dbname='gogs' sslmode='disable' search_path='test' application_name='gogs'", }, { name: "postgres: tcp", @@ -114,7 +114,7 @@ func Test_parseDSN(t *testing.T) { Password: "pa$$word", SSLMode: "disable", }, - expDSN: "user='gogs@local' password='pa$$word' host='127.0.0.1' port='5432' dbname='gogs' sslmode='disable' search_path='test'", + wantDSN: "user='gogs@local' password='pa$$word' host='127.0.0.1' port='5432' dbname='gogs' sslmode='disable' search_path='test' application_name='gogs'", }, { @@ -126,7 +126,7 @@ func Test_parseDSN(t *testing.T) { User: "gogs@local", Password: "pa$$word", }, - expDSN: "server=127.0.0.1; port=1433; database=gogs; user id=gogs@local; password=pa$$word;", + wantDSN: "server=127.0.0.1; port=1433; database=gogs; user id=gogs@local; password=pa$$word;", }, { @@ -135,7 +135,7 @@ func Test_parseDSN(t *testing.T) { Type: "sqlite3", Path: "/tmp/gogs.db", }, - expDSN: "file:/tmp/gogs.db?cache=shared&mode=rwc", + wantDSN: "file:/tmp/gogs.db?cache=shared&mode=rwc", }, } for _, test := range tests { @@ -144,7 +144,7 @@ func Test_parseDSN(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, test.expDSN, dsn) + assert.Equal(t, test.wantDSN, dsn) }) } } diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go index 8d94112f1..eba5faf79 100644 --- a/internal/db/mock_gen.go +++ b/internal/db/mock_gen.go @@ -8,7 +8,7 @@ import ( "testing" ) -//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i TwoFactorsStore -i UsersStore -o mocks.go +//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i ReposStore -i TwoFactorsStore -i UsersStore -o mocks.go func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) { before := AccessTokens @@ -42,16 +42,6 @@ func SetMockPermsStore(t *testing.T, mock PermsStore) { }) } -var _ ReposStore = (*MockReposStore)(nil) - -type MockReposStore struct { - MockGetByName func(ownerID int64, name string) (*Repository, error) -} - -func (m *MockReposStore) GetByName(ownerID int64, name string) (*Repository, error) { - return m.MockGetByName(ownerID, name) -} - func SetMockReposStore(t *testing.T, mock ReposStore) { before := Repos Repos = mock diff --git a/internal/db/mocks.go b/internal/db/mocks.go index d3e302a8f..ddf9ee5dc 100644 --- a/internal/db/mocks.go +++ b/internal/db/mocks.go @@ -2371,6 +2371,159 @@ func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} { return []interface{}{c.Result0} } +// MockReposStore is a mock implementation of the ReposStore interface (from +// the package gogs.io/gogs/internal/db) used for unit testing. +type MockReposStore struct { + // GetByNameFunc is an instance of a mock function object controlling + // the behavior of the method GetByName. + GetByNameFunc *ReposStoreGetByNameFunc +} + +// NewMockReposStore creates a new mock of the ReposStore interface. All +// methods return zero values for all results, unless overwritten. +func NewMockReposStore() *MockReposStore { + return &MockReposStore{ + GetByNameFunc: &ReposStoreGetByNameFunc{ + defaultHook: func(context.Context, int64, string) (r0 *Repository, r1 error) { + return + }, + }, + } +} + +// NewStrictMockReposStore creates a new mock of the ReposStore interface. +// All methods panic on invocation, unless overwritten. +func NewStrictMockReposStore() *MockReposStore { + return &MockReposStore{ + GetByNameFunc: &ReposStoreGetByNameFunc{ + defaultHook: func(context.Context, int64, string) (*Repository, error) { + panic("unexpected invocation of MockReposStore.GetByName") + }, + }, + } +} + +// NewMockReposStoreFrom creates a new mock of the MockReposStore interface. +// All methods delegate to the given implementation, unless overwritten. +func NewMockReposStoreFrom(i ReposStore) *MockReposStore { + return &MockReposStore{ + GetByNameFunc: &ReposStoreGetByNameFunc{ + defaultHook: i.GetByName, + }, + } +} + +// ReposStoreGetByNameFunc describes the behavior when the GetByName method +// of the parent MockReposStore instance is invoked. +type ReposStoreGetByNameFunc struct { + defaultHook func(context.Context, int64, string) (*Repository, error) + hooks []func(context.Context, int64, string) (*Repository, error) + history []ReposStoreGetByNameFuncCall + mutex sync.Mutex +} + +// GetByName delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockReposStore) GetByName(v0 context.Context, v1 int64, v2 string) (*Repository, error) { + r0, r1 := m.GetByNameFunc.nextHook()(v0, v1, v2) + m.GetByNameFunc.appendCall(ReposStoreGetByNameFuncCall{v0, v1, v2, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the GetByName method of +// the parent MockReposStore instance is invoked and the hook queue is +// empty. +func (f *ReposStoreGetByNameFunc) SetDefaultHook(hook func(context.Context, int64, string) (*Repository, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// GetByName method of the parent MockReposStore instance invokes the hook +// at the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *ReposStoreGetByNameFunc) PushHook(hook func(context.Context, int64, string) (*Repository, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *ReposStoreGetByNameFunc) SetDefaultReturn(r0 *Repository, r1 error) { + f.SetDefaultHook(func(context.Context, int64, string) (*Repository, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *ReposStoreGetByNameFunc) PushReturn(r0 *Repository, r1 error) { + f.PushHook(func(context.Context, int64, string) (*Repository, error) { + return r0, r1 + }) +} + +func (f *ReposStoreGetByNameFunc) nextHook() func(context.Context, int64, string) (*Repository, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *ReposStoreGetByNameFunc) appendCall(r0 ReposStoreGetByNameFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of ReposStoreGetByNameFuncCall objects +// describing the invocations of this function. +func (f *ReposStoreGetByNameFunc) History() []ReposStoreGetByNameFuncCall { + f.mutex.Lock() + history := make([]ReposStoreGetByNameFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// ReposStoreGetByNameFuncCall is an object that describes an invocation of +// method GetByName on an instance of MockReposStore. +type ReposStoreGetByNameFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 string + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *Repository + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c ReposStoreGetByNameFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c ReposStoreGetByNameFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + // MockTwoFactorsStore is a mock implementation of the TwoFactorsStore // interface (from the package gogs.io/gogs/internal/db) used for unit // testing. diff --git a/internal/db/repos.go b/internal/db/repos.go index ecdbc0a5b..8b4c5bcee 100644 --- a/internal/db/repos.go +++ b/internal/db/repos.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "strings" "time" @@ -18,14 +19,14 @@ import ( // // NOTE: All methods are sorted in alphabetical order. type ReposStore interface { - // GetByName returns the repository with given owner and name. - // It returns ErrRepoNotExist when not found. - GetByName(ownerID int64, name string) (*Repository, error) + // GetByName returns the repository with given owner and name. It returns + // ErrRepoNotExist when not found. + GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) } var Repos ReposStore -// NOTE: This is a GORM create hook. +// BeforeCreate implements the GORM create hook. func (r *Repository) BeforeCreate(tx *gorm.DB) error { if r.CreatedUnix == 0 { r.CreatedUnix = tx.NowFunc().Unix() @@ -33,13 +34,13 @@ func (r *Repository) BeforeCreate(tx *gorm.DB) error { return nil } -// NOTE: This is a GORM update hook. +// BeforeUpdate implements the GORM update hook. func (r *Repository) BeforeUpdate(tx *gorm.DB) error { r.UpdatedUnix = tx.NowFunc().Unix() return nil } -// NOTE: This is a GORM query hook. +// AfterFind implements the GORM query hook. func (r *Repository) AfterFind(_ *gorm.DB) error { r.Created = time.Unix(r.CreatedUnix, 0).Local() r.Updated = time.Unix(r.UpdatedUnix, 0).Local() @@ -81,13 +82,13 @@ type createRepoOpts struct { // create creates a new repository record in the database. Fields of "repo" will be updated // in place upon insertion. It returns ErrNameNotAllowed when the repository name is not allowed, // or ErrRepoAlreadyExist when a repository with same name already exists for the owner. -func (db *repos) create(ownerID int64, opts createRepoOpts) (*Repository, error) { +func (db *repos) create(ctx context.Context, ownerID int64, opts createRepoOpts) (*Repository, error) { err := isRepoNameAllowed(opts.Name) if err != nil { return nil, err } - _, err = db.GetByName(ownerID, opts.Name) + _, err = db.GetByName(ctx, ownerID, opts.Name) if err == nil { return nil, ErrRepoAlreadyExist{args: errutil.Args{"ownerID": ownerID, "name": opts.Name}} } else if !IsErrRepoNotExist(err) { @@ -108,7 +109,7 @@ func (db *repos) create(ownerID int64, opts createRepoOpts) (*Repository, error) IsFork: opts.Fork, ForkID: opts.ForkID, } - return repo, db.DB.Create(repo).Error + return repo, db.WithContext(ctx).Create(repo).Error } var _ errutil.NotFound = (*ErrRepoNotExist)(nil) @@ -130,9 +131,12 @@ func (ErrRepoNotExist) NotFound() bool { return true } -func (db *repos) GetByName(ownerID int64, name string) (*Repository, error) { +func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) { repo := new(Repository) - err := db.Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).First(repo).Error + err := db.WithContext(ctx). + Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)). + First(repo). + Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrRepoNotExist{args: map[string]interface{}{"ownerID": ownerID, "name": name}} diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go index d248f3b05..324825063 100644 --- a/internal/db/repos_test.go +++ b/internal/db/repos_test.go @@ -5,15 +5,17 @@ package db import ( + "context" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gogs.io/gogs/internal/errutil" ) -func Test_repos(t *testing.T) { +func TestRepos(t *testing.T) { if testing.Short() { t.Skip() } @@ -29,15 +31,13 @@ func Test_repos(t *testing.T) { name string test func(*testing.T, *repos) }{ - {"create", test_repos_create}, - {"GetByName", test_repos_GetByName}, + {"create", reposCreate}, + {"GetByName", reposGetByName}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { err := clearTables(t, db.DB, tables...) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) }) tc.test(t, db) }) @@ -47,58 +47,63 @@ func Test_repos(t *testing.T) { } } -func test_repos_create(t *testing.T, db *repos) { +func reposCreate(t *testing.T, db *repos) { + ctx := context.Background() + t.Run("name not allowed", func(t *testing.T) { - _, err := db.create(1, createRepoOpts{ - Name: "my.git", - }) - expErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "pattern": "*.git"}} - assert.Equal(t, expErr, err) + _, err := db.create(ctx, + 1, + createRepoOpts{ + Name: "my.git", + }, + ) + wantErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "pattern": "*.git"}} + assert.Equal(t, wantErr, err) }) t.Run("already exists", func(t *testing.T) { - _, err := db.create(2, createRepoOpts{ - Name: "repo1", - }) - if err != nil { - t.Fatal(err) - } + _, err := db.create(ctx, 2, + createRepoOpts{ + Name: "repo1", + }, + ) + require.NoError(t, err) - _, err = db.create(2, createRepoOpts{ - Name: "repo1", - }) - expErr := ErrRepoAlreadyExist{args: errutil.Args{"ownerID": int64(2), "name": "repo1"}} - assert.Equal(t, expErr, err) + _, err = db.create(ctx, 2, + createRepoOpts{ + Name: "repo1", + }, + ) + wantErr := ErrRepoAlreadyExist{args: errutil.Args{"ownerID": int64(2), "name": "repo1"}} + assert.Equal(t, wantErr, err) }) - repo, err := db.create(3, createRepoOpts{ - Name: "repo2", - }) - if err != nil { - t.Fatal(err) - } + repo, err := db.create(ctx, 3, + createRepoOpts{ + Name: "repo2", + }, + ) + require.NoError(t, err) - repo, err = db.GetByName(repo.OwnerID, repo.Name) - if err != nil { - t.Fatal(err) - } + repo, err = db.GetByName(ctx, repo.OwnerID, repo.Name) + require.NoError(t, err) assert.Equal(t, db.NowFunc().Format(time.RFC3339), repo.Created.UTC().Format(time.RFC3339)) } -func test_repos_GetByName(t *testing.T, db *repos) { - repo, err := db.create(1, createRepoOpts{ - Name: "repo1", - }) - if err != nil { - t.Fatal(err) - } +func reposGetByName(t *testing.T, db *repos) { + ctx := context.Background() - _, err = db.GetByName(repo.OwnerID, repo.Name) - if err != nil { - t.Fatal(err) - } + repo, err := db.create(ctx, 1, + createRepoOpts{ + Name: "repo1", + }, + ) + require.NoError(t, err) - _, err = db.GetByName(1, "bad_name") - expErr := ErrRepoNotExist{args: errutil.Args{"ownerID": int64(1), "name": "bad_name"}} - assert.Equal(t, expErr, err) + _, err = db.GetByName(ctx, repo.OwnerID, repo.Name) + require.NoError(t, err) + + _, err = db.GetByName(ctx, 1, "bad_name") + wantErr := ErrRepoNotExist{args: errutil.Args{"ownerID": int64(1), "name": "bad_name"}} + assert.Equal(t, wantErr, err) } diff --git a/internal/route/api/v1/api.go b/internal/route/api/v1/api.go index 9d256e3e3..a517c6df8 100644 --- a/internal/route/api/v1/api.go +++ b/internal/route/api/v1/api.go @@ -45,7 +45,7 @@ func repoAssignment() macaron.Handler { } c.Repo.Owner = owner - repo, err := db.Repos.GetByName(owner.ID, reponame) + repo, err := db.Repos.GetByName(c.Req.Context(), owner.ID, reponame) if err != nil { c.NotFoundOrError(err, "get repository by name") return diff --git a/internal/route/lfs/route.go b/internal/route/lfs/route.go index 94c42fea0..c00f7374b 100644 --- a/internal/route/lfs/route.go +++ b/internal/route/lfs/route.go @@ -119,7 +119,7 @@ func authorize(mode db.AccessMode) macaron.Handler { return } - repo, err := db.Repos.GetByName(owner.ID, reponame) + repo, err := db.Repos.GetByName(c.Req.Context(), owner.ID, reponame) if err != nil { if db.IsErrRepoNotExist(err) { c.Status(http.StatusNotFound) diff --git a/internal/route/lfs/route_test.go b/internal/route/lfs/route_test.go index ee668fc63..cc6fb1beb 100644 --- a/internal/route/lfs/route_test.go +++ b/internal/route/lfs/route_test.go @@ -167,7 +167,7 @@ func Test_authorize(t *testing.T) { name string authroize macaron.Handler mockUsersStore func() db.UsersStore - mockReposStore *db.MockReposStore + mockReposStore func() db.ReposStore mockPermsStore func() db.PermsStore expStatusCode int expBody string @@ -192,10 +192,10 @@ func Test_authorize(t *testing.T) { }) return mock }, - mockReposStore: &db.MockReposStore{ - MockGetByName: func(ownerID int64, name string) (*db.Repository, error) { - return nil, db.ErrRepoNotExist{} - }, + mockReposStore: func() db.ReposStore { + mock := db.NewMockReposStore() + mock.GetByNameFunc.SetDefaultReturn(nil, db.ErrRepoNotExist{}) + return mock }, expStatusCode: http.StatusNotFound, }, @@ -209,10 +209,12 @@ func Test_authorize(t *testing.T) { }) return mock }, - mockReposStore: &db.MockReposStore{ - MockGetByName: func(ownerID int64, name string) (*db.Repository, error) { + mockReposStore: func() db.ReposStore { + mock := db.NewMockReposStore() + mock.GetByNameFunc.SetDefaultHook(func(ctx context.Context, ownerID int64, name string) (*db.Repository, error) { return &db.Repository{Name: name}, nil - }, + }) + return mock }, mockPermsStore: func() db.PermsStore { mock := db.NewMockPermsStore() @@ -234,10 +236,12 @@ func Test_authorize(t *testing.T) { }) return mock }, - mockReposStore: &db.MockReposStore{ - MockGetByName: func(ownerID int64, name string) (*db.Repository, error) { + mockReposStore: func() db.ReposStore { + mock := db.NewMockReposStore() + mock.GetByNameFunc.SetDefaultHook(func(ctx context.Context, ownerID int64, name string) (*db.Repository, error) { return &db.Repository{Name: name}, nil - }, + }) + return mock }, mockPermsStore: func() db.PermsStore { mock := db.NewMockPermsStore() @@ -255,7 +259,9 @@ func Test_authorize(t *testing.T) { if test.mockUsersStore != nil { db.SetMockUsersStore(t, test.mockUsersStore()) } - db.SetMockReposStore(t, test.mockReposStore) + if test.mockReposStore != nil { + db.SetMockReposStore(t, test.mockReposStore()) + } if test.mockPermsStore != nil { db.SetMockPermsStore(t, test.mockPermsStore()) } diff --git a/internal/route/repo/http.go b/internal/route/repo/http.go index 888ad4d86..af51a076b 100644 --- a/internal/route/repo/http.go +++ b/internal/route/repo/http.go @@ -76,7 +76,7 @@ func HTTPContexter() macaron.Handler { return } - repo, err := db.Repos.GetByName(owner.ID, repoName) + repo, err := db.Repos.GetByName(c.Req.Context(), owner.ID, repoName) if err != nil { if db.IsErrRepoNotExist(err) { c.Status(http.StatusNotFound) diff --git a/internal/route/repo/tasks.go b/internal/route/repo/tasks.go index c5b555b9c..d92158856 100644 --- a/internal/route/repo/tasks.go +++ b/internal/route/repo/tasks.go @@ -44,7 +44,7 @@ func TriggerTask(c *macaron.Context) { return } - repo, err := db.Repos.GetByName(owner.ID, reponame) + repo, err := db.Repos.GetByName(c.Req.Context(), owner.ID, reponame) if err != nil { if db.IsErrRepoNotExist(err) { c.Error(http.StatusBadRequest, "Repository does not exist")