From 94059f291dc85bfe222ca8d9094fd11334c0dd53 Mon Sep 17 00:00:00 2001 From: Joe Chen Date: Fri, 10 Jun 2022 11:27:06 +0800 Subject: [PATCH] db: use `context` and go-mockgen for `LFSStore` (#7038) --- internal/db/lfs.go | 37 +-- internal/db/lfs_test.go | 72 +++--- internal/db/mock_gen.go | 24 +- internal/db/mocks.go | 419 +++++++++++++++++++++++++++++++ internal/route/lfs/basic.go | 8 +- internal/route/lfs/basic_test.go | 86 +++---- internal/route/lfs/batch.go | 2 +- internal/route/lfs/batch_test.go | 20 +- 8 files changed, 532 insertions(+), 136 deletions(-) diff --git a/internal/db/lfs.go b/internal/db/lfs.go index 43515d8cf..1a34b8023 100644 --- a/internal/db/lfs.go +++ b/internal/db/lfs.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "time" @@ -19,24 +20,24 @@ import ( // NOTE: All methods are sorted in alphabetical order. type LFSStore interface { // CreateObject creates a LFS object record in database. - CreateObject(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error - // GetObjectByOID returns the LFS object with given OID. It returns ErrLFSObjectNotExist - // when not found. - GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error) - // GetObjectsByOIDs returns LFS objects found within "oids". The returned list could have - // less elements if some oids were not found. - GetObjectsByOIDs(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) + CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error + // GetObjectByOID returns the LFS object with given OID. It returns + // ErrLFSObjectNotExist when not found. + GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) + // GetObjectsByOIDs returns LFS objects found within "oids". The returned list + // could have less elements if some oids were not found. + GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) } var LFS LFSStore // LFSObject is the relation between an LFS object and a repository. type LFSObject struct { - RepoID int64 `gorm:"PRIMARY_KEY;AUTO_INCREMENT:false"` - OID lfsutil.OID `gorm:"PRIMARY_KEY;COLUMN:oid"` - Size int64 `gorm:"NOT NULL"` - Storage lfsutil.Storage `gorm:"NOT NULL"` - CreatedAt time.Time `gorm:"NOT NULL"` + RepoID int64 `gorm:"primary_key;auto_increment:false"` + OID lfsutil.OID `gorm:"primary_key;column:oid"` + Size int64 `gorm:"not null"` + Storage lfsutil.Storage `gorm:"not null"` + CreatedAt time.Time `gorm:"not null"` } var _ LFSStore = (*lfs)(nil) @@ -45,14 +46,14 @@ type lfs struct { *gorm.DB } -func (db *lfs) CreateObject(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error { +func (db *lfs) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error { object := &LFSObject{ RepoID: repoID, OID: oid, Size: size, Storage: storage, } - return db.DB.Create(object).Error + return db.WithContext(ctx).Create(object).Error } type ErrLFSObjectNotExist struct { @@ -72,9 +73,9 @@ func (ErrLFSObjectNotExist) NotFound() bool { return true } -func (db *lfs) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error) { +func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) { object := new(LFSObject) - err := db.Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error + err := db.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}} @@ -84,13 +85,13 @@ func (db *lfs) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error) return object, err } -func (db *lfs) GetObjectsByOIDs(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) { +func (db *lfs) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) { if len(oids) == 0 { return []*LFSObject{}, nil } objects := make([]*LFSObject, 0, len(oids)) - err := db.Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error + err := db.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error if err != nil && err != gorm.ErrRecordNotFound { return nil, err } diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go index 6c7af93a1..e650369c9 100644 --- a/internal/db/lfs_test.go +++ b/internal/db/lfs_test.go @@ -5,16 +5,18 @@ package db import ( + "context" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/lfsutil" ) -func Test_lfs(t *testing.T) { +func TestLFS(t *testing.T) { if testing.Short() { t.Skip() } @@ -30,16 +32,14 @@ func Test_lfs(t *testing.T) { name string test func(*testing.T, *lfs) }{ - {"CreateObject", test_lfs_CreateObject}, - {"GetObjectByOID", test_lfs_GetObjectByOID}, - {"GetObjectsByOIDs", test_lfs_GetObjectsByOIDs}, + {"CreateObject", lfsCreateObject}, + {"GetObjectByOID", lfsGetObjectByOID}, + {"GetObjectsByOIDs", lfsGetObjectsByOIDs}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { err := clearTables(t, db.DB, tables...) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) }) tc.test(t, db) }) @@ -49,67 +49,59 @@ func Test_lfs(t *testing.T) { } } -func test_lfs_CreateObject(t *testing.T, db *lfs) { +func lfsCreateObject(t *testing.T, db *lfs) { + ctx := context.Background() + // Create first LFS object repoID := int64(1) oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") - err := db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal) - if err != nil { - t.Fatal(err) - } + err := db.CreateObject(ctx, repoID, oid, 12, lfsutil.StorageLocal) + require.NoError(t, err) // Get it back and check the CreatedAt field - object, err := db.GetObjectByOID(repoID, oid) - if err != nil { - t.Fatal(err) - } + object, err := db.GetObjectByOID(ctx, repoID, oid) + require.NoError(t, err) assert.Equal(t, db.NowFunc().Format(time.RFC3339), object.CreatedAt.UTC().Format(time.RFC3339)) // Try create second LFS object with same oid should fail - err = db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal) + err = db.CreateObject(ctx, repoID, oid, 12, lfsutil.StorageLocal) assert.Error(t, err) } -func test_lfs_GetObjectByOID(t *testing.T, db *lfs) { +func lfsGetObjectByOID(t *testing.T, db *lfs) { + ctx := context.Background() + // Create a LFS object repoID := int64(1) oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") - err := db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal) - if err != nil { - t.Fatal(err) - } + err := db.CreateObject(ctx, repoID, oid, 12, lfsutil.StorageLocal) + require.NoError(t, err) // We should be able to get it back - _, err = db.GetObjectByOID(repoID, oid) - if err != nil { - t.Fatal(err) - } + _, err = db.GetObjectByOID(ctx, repoID, oid) + require.NoError(t, err) // Try to get a non-existent object - _, err = db.GetObjectByOID(repoID, "bad_oid") + _, err = db.GetObjectByOID(ctx, repoID, "bad_oid") expErr := ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": lfsutil.OID("bad_oid")}} assert.Equal(t, expErr, err) } -func test_lfs_GetObjectsByOIDs(t *testing.T, db *lfs) { +func lfsGetObjectsByOIDs(t *testing.T, db *lfs) { + ctx := context.Background() + // Create two LFS objects repoID := int64(1) oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") oid2 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64g") - err := db.CreateObject(repoID, oid1, 12, lfsutil.StorageLocal) - if err != nil { - t.Fatal(err) - } - err = db.CreateObject(repoID, oid2, 12, lfsutil.StorageLocal) - if err != nil { - t.Fatal(err) - } + err := db.CreateObject(ctx, repoID, oid1, 12, lfsutil.StorageLocal) + require.NoError(t, err) + err = db.CreateObject(ctx, repoID, oid2, 12, lfsutil.StorageLocal) + require.NoError(t, err) // We should be able to get them back and ignore non-existent ones - objects, err := db.GetObjectsByOIDs(repoID, oid1, oid2, "bad_oid") - if err != nil { - t.Fatal(err) - } + objects, err := db.GetObjectsByOIDs(ctx, repoID, oid1, oid2, "bad_oid") + require.NoError(t, err) assert.Equal(t, 2, len(objects), "number of objects") assert.Equal(t, repoID, objects[0].RepoID) diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go index ce347a637..b3821287f 100644 --- a/internal/db/mock_gen.go +++ b/internal/db/mock_gen.go @@ -6,11 +6,9 @@ package db import ( "testing" - - "gogs.io/gogs/internal/lfsutil" ) -//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i PermsStore -o mocks.go +//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i PermsStore -o mocks.go func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) { before := AccessTokens @@ -20,26 +18,6 @@ func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) { }) } -var _ LFSStore = (*MockLFSStore)(nil) - -type MockLFSStore struct { - MockCreateObject func(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error - MockGetObjectByOID func(repoID int64, oid lfsutil.OID) (*LFSObject, error) - MockGetObjectsByOIDs func(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) -} - -func (m *MockLFSStore) CreateObject(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error { - return m.MockCreateObject(repoID, oid, size, storage) -} - -func (m *MockLFSStore) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error) { - return m.MockGetObjectByOID(repoID, oid) -} - -func (m *MockLFSStore) GetObjectsByOIDs(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) { - return m.MockGetObjectsByOIDs(repoID, oids...) -} - func SetMockLFSStore(t *testing.T, mock LFSStore) { before := LFS LFS = mock diff --git a/internal/db/mocks.go b/internal/db/mocks.go index e969d83fe..e5a24d065 100644 --- a/internal/db/mocks.go +++ b/internal/db/mocks.go @@ -5,6 +5,8 @@ package db import ( "context" "sync" + + lfsutil "gogs.io/gogs/internal/lfsutil" ) // MockAccessTokensStore is a mock implementation of the AccessTokensStore @@ -658,6 +660,423 @@ func (c AccessTokensStoreTouchFuncCall) Results() []interface{} { return []interface{}{c.Result0} } +// MockLFSStore is a mock implementation of the LFSStore interface (from the +// package gogs.io/gogs/internal/db) used for unit testing. +type MockLFSStore struct { + // CreateObjectFunc is an instance of a mock function object controlling + // the behavior of the method CreateObject. + CreateObjectFunc *LFSStoreCreateObjectFunc + // GetObjectByOIDFunc is an instance of a mock function object + // controlling the behavior of the method GetObjectByOID. + GetObjectByOIDFunc *LFSStoreGetObjectByOIDFunc + // GetObjectsByOIDsFunc is an instance of a mock function object + // controlling the behavior of the method GetObjectsByOIDs. + GetObjectsByOIDsFunc *LFSStoreGetObjectsByOIDsFunc +} + +// NewMockLFSStore creates a new mock of the LFSStore interface. All methods +// return zero values for all results, unless overwritten. +func NewMockLFSStore() *MockLFSStore { + return &MockLFSStore{ + CreateObjectFunc: &LFSStoreCreateObjectFunc{ + defaultHook: func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) (r0 error) { + return + }, + }, + GetObjectByOIDFunc: &LFSStoreGetObjectByOIDFunc{ + defaultHook: func(context.Context, int64, lfsutil.OID) (r0 *LFSObject, r1 error) { + return + }, + }, + GetObjectsByOIDsFunc: &LFSStoreGetObjectsByOIDsFunc{ + defaultHook: func(context.Context, int64, ...lfsutil.OID) (r0 []*LFSObject, r1 error) { + return + }, + }, + } +} + +// NewStrictMockLFSStore creates a new mock of the LFSStore interface. All +// methods panic on invocation, unless overwritten. +func NewStrictMockLFSStore() *MockLFSStore { + return &MockLFSStore{ + CreateObjectFunc: &LFSStoreCreateObjectFunc{ + defaultHook: func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error { + panic("unexpected invocation of MockLFSStore.CreateObject") + }, + }, + GetObjectByOIDFunc: &LFSStoreGetObjectByOIDFunc{ + defaultHook: func(context.Context, int64, lfsutil.OID) (*LFSObject, error) { + panic("unexpected invocation of MockLFSStore.GetObjectByOID") + }, + }, + GetObjectsByOIDsFunc: &LFSStoreGetObjectsByOIDsFunc{ + defaultHook: func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) { + panic("unexpected invocation of MockLFSStore.GetObjectsByOIDs") + }, + }, + } +} + +// NewMockLFSStoreFrom creates a new mock of the MockLFSStore interface. All +// methods delegate to the given implementation, unless overwritten. +func NewMockLFSStoreFrom(i LFSStore) *MockLFSStore { + return &MockLFSStore{ + CreateObjectFunc: &LFSStoreCreateObjectFunc{ + defaultHook: i.CreateObject, + }, + GetObjectByOIDFunc: &LFSStoreGetObjectByOIDFunc{ + defaultHook: i.GetObjectByOID, + }, + GetObjectsByOIDsFunc: &LFSStoreGetObjectsByOIDsFunc{ + defaultHook: i.GetObjectsByOIDs, + }, + } +} + +// LFSStoreCreateObjectFunc describes the behavior when the CreateObject +// method of the parent MockLFSStore instance is invoked. +type LFSStoreCreateObjectFunc struct { + defaultHook func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error + hooks []func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error + history []LFSStoreCreateObjectFuncCall + mutex sync.Mutex +} + +// CreateObject delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockLFSStore) CreateObject(v0 context.Context, v1 int64, v2 lfsutil.OID, v3 int64, v4 lfsutil.Storage) error { + r0 := m.CreateObjectFunc.nextHook()(v0, v1, v2, v3, v4) + m.CreateObjectFunc.appendCall(LFSStoreCreateObjectFuncCall{v0, v1, v2, v3, v4, r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the CreateObject method +// of the parent MockLFSStore instance is invoked and the hook queue is +// empty. +func (f *LFSStoreCreateObjectFunc) SetDefaultHook(hook func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// CreateObject method of the parent MockLFSStore instance invokes the hook +// at the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *LFSStoreCreateObjectFunc) PushHook(hook func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *LFSStoreCreateObjectFunc) SetDefaultReturn(r0 error) { + f.SetDefaultHook(func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *LFSStoreCreateObjectFunc) PushReturn(r0 error) { + f.PushHook(func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error { + return r0 + }) +} + +func (f *LFSStoreCreateObjectFunc) nextHook() func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *LFSStoreCreateObjectFunc) appendCall(r0 LFSStoreCreateObjectFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of LFSStoreCreateObjectFuncCall objects +// describing the invocations of this function. +func (f *LFSStoreCreateObjectFunc) History() []LFSStoreCreateObjectFuncCall { + f.mutex.Lock() + history := make([]LFSStoreCreateObjectFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// LFSStoreCreateObjectFuncCall is an object that describes an invocation of +// method CreateObject on an instance of MockLFSStore. +type LFSStoreCreateObjectFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 lfsutil.OID + // Arg3 is the value of the 4th argument passed to this method + // invocation. + Arg3 int64 + // Arg4 is the value of the 5th argument passed to this method + // invocation. + Arg4 lfsutil.Storage + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c LFSStoreCreateObjectFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3, c.Arg4} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c LFSStoreCreateObjectFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} + +// LFSStoreGetObjectByOIDFunc describes the behavior when the GetObjectByOID +// method of the parent MockLFSStore instance is invoked. +type LFSStoreGetObjectByOIDFunc struct { + defaultHook func(context.Context, int64, lfsutil.OID) (*LFSObject, error) + hooks []func(context.Context, int64, lfsutil.OID) (*LFSObject, error) + history []LFSStoreGetObjectByOIDFuncCall + mutex sync.Mutex +} + +// GetObjectByOID delegates to the next hook function in the queue and +// stores the parameter and result values of this invocation. +func (m *MockLFSStore) GetObjectByOID(v0 context.Context, v1 int64, v2 lfsutil.OID) (*LFSObject, error) { + r0, r1 := m.GetObjectByOIDFunc.nextHook()(v0, v1, v2) + m.GetObjectByOIDFunc.appendCall(LFSStoreGetObjectByOIDFuncCall{v0, v1, v2, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the GetObjectByOID +// method of the parent MockLFSStore instance is invoked and the hook queue +// is empty. +func (f *LFSStoreGetObjectByOIDFunc) SetDefaultHook(hook func(context.Context, int64, lfsutil.OID) (*LFSObject, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// GetObjectByOID method of the parent MockLFSStore instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *LFSStoreGetObjectByOIDFunc) PushHook(hook func(context.Context, int64, lfsutil.OID) (*LFSObject, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *LFSStoreGetObjectByOIDFunc) SetDefaultReturn(r0 *LFSObject, r1 error) { + f.SetDefaultHook(func(context.Context, int64, lfsutil.OID) (*LFSObject, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *LFSStoreGetObjectByOIDFunc) PushReturn(r0 *LFSObject, r1 error) { + f.PushHook(func(context.Context, int64, lfsutil.OID) (*LFSObject, error) { + return r0, r1 + }) +} + +func (f *LFSStoreGetObjectByOIDFunc) nextHook() func(context.Context, int64, lfsutil.OID) (*LFSObject, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *LFSStoreGetObjectByOIDFunc) appendCall(r0 LFSStoreGetObjectByOIDFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of LFSStoreGetObjectByOIDFuncCall objects +// describing the invocations of this function. +func (f *LFSStoreGetObjectByOIDFunc) History() []LFSStoreGetObjectByOIDFuncCall { + f.mutex.Lock() + history := make([]LFSStoreGetObjectByOIDFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// LFSStoreGetObjectByOIDFuncCall is an object that describes an invocation +// of method GetObjectByOID on an instance of MockLFSStore. +type LFSStoreGetObjectByOIDFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 lfsutil.OID + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *LFSObject + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c LFSStoreGetObjectByOIDFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c LFSStoreGetObjectByOIDFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + +// LFSStoreGetObjectsByOIDsFunc describes the behavior when the +// GetObjectsByOIDs method of the parent MockLFSStore instance is invoked. +type LFSStoreGetObjectsByOIDsFunc struct { + defaultHook func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) + hooks []func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) + history []LFSStoreGetObjectsByOIDsFuncCall + mutex sync.Mutex +} + +// GetObjectsByOIDs delegates to the next hook function in the queue and +// stores the parameter and result values of this invocation. +func (m *MockLFSStore) GetObjectsByOIDs(v0 context.Context, v1 int64, v2 ...lfsutil.OID) ([]*LFSObject, error) { + r0, r1 := m.GetObjectsByOIDsFunc.nextHook()(v0, v1, v2...) + m.GetObjectsByOIDsFunc.appendCall(LFSStoreGetObjectsByOIDsFuncCall{v0, v1, v2, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the GetObjectsByOIDs +// method of the parent MockLFSStore instance is invoked and the hook queue +// is empty. +func (f *LFSStoreGetObjectsByOIDsFunc) SetDefaultHook(hook func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// GetObjectsByOIDs method of the parent MockLFSStore instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *LFSStoreGetObjectsByOIDsFunc) PushHook(hook func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *LFSStoreGetObjectsByOIDsFunc) SetDefaultReturn(r0 []*LFSObject, r1 error) { + f.SetDefaultHook(func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *LFSStoreGetObjectsByOIDsFunc) PushReturn(r0 []*LFSObject, r1 error) { + f.PushHook(func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) { + return r0, r1 + }) +} + +func (f *LFSStoreGetObjectsByOIDsFunc) nextHook() func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *LFSStoreGetObjectsByOIDsFunc) appendCall(r0 LFSStoreGetObjectsByOIDsFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of LFSStoreGetObjectsByOIDsFuncCall objects +// describing the invocations of this function. +func (f *LFSStoreGetObjectsByOIDsFunc) History() []LFSStoreGetObjectsByOIDsFuncCall { + f.mutex.Lock() + history := make([]LFSStoreGetObjectsByOIDsFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// LFSStoreGetObjectsByOIDsFuncCall is an object that describes an +// invocation of method GetObjectsByOIDs on an instance of MockLFSStore. +type LFSStoreGetObjectsByOIDsFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is a slice containing the values of the variadic arguments + // passed to this method invocation. + Arg2 []lfsutil.OID + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 []*LFSObject + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. The variadic slice argument is flattened in this array such +// that one positional argument and three variadic arguments would result in +// a slice of four, not two. +func (c LFSStoreGetObjectsByOIDsFuncCall) Args() []interface{} { + trailing := []interface{}{} + for _, val := range c.Arg2 { + trailing = append(trailing, val) + } + + return append([]interface{}{c.Arg0, c.Arg1}, trailing...) +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c LFSStoreGetObjectsByOIDsFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + // MockPermsStore is a mock implementation of the PermsStore interface (from // the package gogs.io/gogs/internal/db) used for unit testing. type MockPermsStore struct { diff --git a/internal/route/lfs/basic.go b/internal/route/lfs/basic.go index a0594839b..cbfc724fe 100644 --- a/internal/route/lfs/basic.go +++ b/internal/route/lfs/basic.go @@ -44,7 +44,7 @@ func (h *basicHandler) Storager(storage lfsutil.Storage) lfsutil.Storager { // GET /{owner}/{repo}.git/info/lfs/object/basic/{oid} func (h *basicHandler) serveDownload(c *macaron.Context, repo *db.Repository, oid lfsutil.OID) { - object, err := db.LFS.GetObjectByOID(repo.ID, oid) + object, err := db.LFS.GetObjectByOID(c.Req.Context(), repo.ID, oid) if err != nil { if db.IsErrLFSObjectNotExist(err) { responseJSON(c.Resp, http.StatusNotFound, responseError{ @@ -79,7 +79,7 @@ func (h *basicHandler) serveDownload(c *macaron.Context, repo *db.Repository, oi func (h *basicHandler) serveUpload(c *macaron.Context, repo *db.Repository, oid lfsutil.OID) { // NOTE: LFS client will retry upload the same object if there was a partial failure, // therefore we would like to skip ones that already exist. - _, err := db.LFS.GetObjectByOID(repo.ID, oid) + _, err := db.LFS.GetObjectByOID(c.Req.Context(), repo.ID, oid) if err == nil { // Object exists, drain the request body and we're good. _, _ = io.Copy(ioutil.Discard, c.Req.Request.Body) @@ -106,7 +106,7 @@ func (h *basicHandler) serveUpload(c *macaron.Context, repo *db.Repository, oid return } - err = db.LFS.CreateObject(repo.ID, oid, written, s.Storage()) + err = db.LFS.CreateObject(c.Req.Context(), repo.ID, oid, written, s.Storage()) if err != nil { // NOTE: It is OK to leave the file when the whole operation failed // with a DB error, a retry on client side can safely overwrite the @@ -139,7 +139,7 @@ func (*basicHandler) serveVerify(c *macaron.Context, repo *db.Repository) { return } - object, err := db.LFS.GetObjectByOID(repo.ID, request.Oid) + object, err := db.LFS.GetObjectByOID(c.Req.Context(), repo.ID, request.Oid) if err != nil { if db.IsErrLFSObjectNotExist(err) { responseJSON(c.Resp, http.StatusNotFound, responseError{ diff --git a/internal/route/lfs/basic_test.go b/internal/route/lfs/basic_test.go index 6343fdcf3..85d03188a 100644 --- a/internal/route/lfs/basic_test.go +++ b/internal/route/lfs/basic_test.go @@ -61,17 +61,17 @@ func Test_basicHandler_serveDownload(t *testing.T) { tests := []struct { name string content string - mockLFSStore *db.MockLFSStore + mockLFSStore func() db.LFSStore expStatusCode int expHeader http.Header expBody string }{ { name: "object does not exist", - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return nil, db.ErrLFSObjectNotExist{} - }, + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn(nil, db.ErrLFSObjectNotExist{}) + return mock }, expStatusCode: http.StatusNotFound, expHeader: http.Header{ @@ -81,10 +81,10 @@ func Test_basicHandler_serveDownload(t *testing.T) { }, { name: "storage not found", - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return &db.LFSObject{Storage: "bad_storage"}, nil - }, + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{Storage: "bad_storage"}, nil) + return mock }, expStatusCode: http.StatusInternalServerError, expHeader: http.Header{ @@ -96,13 +96,16 @@ func Test_basicHandler_serveDownload(t *testing.T) { { name: "object exists", content: "Hello world!", - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return &db.LFSObject{ + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn( + &db.LFSObject{ Size: 12, Storage: s.Storage(), - }, nil - }, + }, + nil, + ) + return mock }, expStatusCode: http.StatusOK, expHeader: http.Header{ @@ -114,7 +117,7 @@ func Test_basicHandler_serveDownload(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - db.SetMockLFSStore(t, test.mockLFSStore) + db.SetMockLFSStore(t, test.mockLFSStore()) s.buf = bytes.NewBufferString(test.content) @@ -158,35 +161,32 @@ func Test_basicHandler_serveUpload(t *testing.T) { tests := []struct { name string - mockLFSStore *db.MockLFSStore + mockLFSStore func() db.LFSStore expStatusCode int expBody string }{ { name: "object already exists", - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return &db.LFSObject{}, nil - }, + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{}, nil) + return mock }, expStatusCode: http.StatusOK, }, { name: "new object", - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return nil, db.ErrLFSObjectNotExist{} - }, - MockCreateObject: func(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error { - return nil - }, + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn(nil, db.ErrLFSObjectNotExist{}) + return mock }, expStatusCode: http.StatusOK, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - db.SetMockLFSStore(t, test.mockLFSStore) + db.SetMockLFSStore(t, test.mockLFSStore()) r, err := http.NewRequest("PUT", "/", strings.NewReader("Hello world!")) if err != nil { @@ -219,7 +219,7 @@ func Test_basicHandler_serveVerify(t *testing.T) { tests := []struct { name string body string - mockLFSStore *db.MockLFSStore + mockLFSStore func() db.LFSStore expStatusCode int expBody string }{ @@ -232,10 +232,10 @@ func Test_basicHandler_serveVerify(t *testing.T) { { name: "object does not exist", body: `{"oid":"ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f"}`, - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return nil, db.ErrLFSObjectNotExist{} - }, + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn(nil, db.ErrLFSObjectNotExist{}) + return mock }, expStatusCode: http.StatusNotFound, expBody: `{"message":"Object does not exist"}` + "\n", @@ -243,10 +243,10 @@ func Test_basicHandler_serveVerify(t *testing.T) { { name: "object size mismatch", body: `{"oid":"ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f"}`, - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return &db.LFSObject{Size: 12}, nil - }, + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{Size: 12}, nil) + return mock }, expStatusCode: http.StatusBadRequest, expBody: `{"message":"Object size mismatch"}` + "\n", @@ -255,17 +255,19 @@ func Test_basicHandler_serveVerify(t *testing.T) { { name: "object exists", body: `{"oid":"ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f", "size":12}`, - mockLFSStore: &db.MockLFSStore{ - MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) { - return &db.LFSObject{Size: 12}, nil - }, + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{Size: 12}, nil) + return mock }, expStatusCode: http.StatusOK, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - db.SetMockLFSStore(t, test.mockLFSStore) + if test.mockLFSStore != nil { + db.SetMockLFSStore(t, test.mockLFSStore()) + } r, err := http.NewRequest("POST", "/", strings.NewReader(test.body)) if err != nil { diff --git a/internal/route/lfs/batch.go b/internal/route/lfs/batch.go index bfc364c20..bde3140d8 100644 --- a/internal/route/lfs/batch.go +++ b/internal/route/lfs/batch.go @@ -75,7 +75,7 @@ func serveBatch(c *macaron.Context, owner *db.User, repo *db.Repository) { for _, obj := range request.Objects { oids = append(oids, obj.Oid) } - stored, err := db.LFS.GetObjectsByOIDs(repo.ID, oids...) + stored, err := db.LFS.GetObjectsByOIDs(c.Req.Context(), repo.ID, oids...) if err != nil { internalServerError(c.Resp) log.Error("Failed to get objects [repo_id: %d, oids: %v]: %v", repo.ID, oids, err) diff --git a/internal/route/lfs/batch_test.go b/internal/route/lfs/batch_test.go index 67b85eeb0..76c0a8174 100644 --- a/internal/route/lfs/batch_test.go +++ b/internal/route/lfs/batch_test.go @@ -17,7 +17,6 @@ import ( "gogs.io/gogs/internal/conf" "gogs.io/gogs/internal/db" - "gogs.io/gogs/internal/lfsutil" ) func Test_serveBatch(t *testing.T) { @@ -35,7 +34,7 @@ func Test_serveBatch(t *testing.T) { tests := []struct { name string body string - mockLFSStore *db.MockLFSStore + mockLFSStore func() db.LFSStore expStatusCode int expBody string }{ @@ -83,9 +82,10 @@ func Test_serveBatch(t *testing.T) { {"oid": "ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f", "size": 123}, {"oid": "5cac0a318669fadfee734fb340a5f5b70b428ac57a9f4b109cb6e150b2ba7e57", "size": 456} ]}`, - mockLFSStore: &db.MockLFSStore{ - MockGetObjectsByOIDs: func(repoID int64, oids ...lfsutil.OID) ([]*db.LFSObject, error) { - return []*db.LFSObject{ + mockLFSStore: func() db.LFSStore { + mock := db.NewMockLFSStore() + mock.GetObjectsByOIDsFunc.SetDefaultReturn( + []*db.LFSObject{ { OID: "ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f", Size: 1234, @@ -93,8 +93,10 @@ func Test_serveBatch(t *testing.T) { OID: "5cac0a318669fadfee734fb340a5f5b70b428ac57a9f4b109cb6e150b2ba7e57", Size: 456, }, - }, nil - }, + }, + nil, + ) + return mock }, expStatusCode: http.StatusOK, expBody: `{ @@ -121,7 +123,9 @@ func Test_serveBatch(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - db.SetMockLFSStore(t, test.mockLFSStore) + if test.mockLFSStore != nil { + db.SetMockLFSStore(t, test.mockLFSStore()) + } r, err := http.NewRequest("POST", "/", bytes.NewBufferString(test.body)) if err != nil {