From 5d083e35f0d2b0c31e194b2726f6e2c146001835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 17 Jan 2021 10:54:21 -0300 Subject: [PATCH] Add tests for the Transaction function --- contracts.go | 2 +- examples/example_service/mocks.go | 14 ++++++ kiss_orm.go | 23 +++++++--- kiss_orm_test.go | 72 +++++++++++++++++++++++++++++++ mocks.go | 9 +++- 5 files changed, 112 insertions(+), 8 deletions(-) diff --git a/contracts.go b/contracts.go index d90d673..66962a5 100644 --- a/contracts.go +++ b/contracts.go @@ -22,7 +22,7 @@ type ORMProvider interface { QueryChunks(ctx context.Context, parser ChunkParser) error Exec(ctx context.Context, query string, params ...interface{}) error - Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) + Transaction(ctx context.Context, fn func(ORMProvider) error) error } // ChunkParser stores the arguments of the QueryChunks function diff --git a/examples/example_service/mocks.go b/examples/example_service/mocks.go index af57e07..21ad1b5 100644 --- a/examples/example_service/mocks.go +++ b/examples/example_service/mocks.go @@ -144,6 +144,20 @@ func (mr *MockORMProviderMockRecorder) QueryOne(ctx, record, query interface{}, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockORMProvider)(nil).QueryOne), varargs...) } +// Transaction mocks base method. +func (m *MockORMProvider) Transaction(ctx context.Context, fn func(kissorm.ORMProvider) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Transaction", ctx, fn) + ret0, _ := ret[0].(error) + return ret0 +} + +// Transaction indicates an expected call of Transaction. +func (mr *MockORMProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockORMProvider)(nil).Transaction), ctx, fn) +} + // Update mocks base method. func (m *MockORMProvider) Update(ctx context.Context, records ...interface{}) error { m.ctrl.T.Helper() diff --git a/kiss_orm.go b/kiss_orm.go index 1cf5def..4e4a00a 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -6,6 +6,8 @@ import ( "fmt" "reflect" "strings" + + "github.com/pkg/errors" ) // DB ... @@ -504,19 +506,23 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error } // Transaction just runs an SQL command on the database returning no rows. -func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) { +func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) error { switch db := c.db.(type) { case *sql.Tx: return fn(c) case *sql.DB: - var tx *sql.Tx - tx, err = db.BeginTx(ctx, nil) + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } defer func() { if r := recover(); r != nil { - _ = tx.Rollback() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + r = errors.Wrap(rollbackErr, + fmt.Sprintf("unable to rollback after panic with value: %v", r), + ) + } panic(r) } }() @@ -526,14 +532,19 @@ func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err er err = fn(ormCopy) if err != nil { - _ = tx.Rollback() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + err = errors.Wrap(rollbackErr, + fmt.Sprintf("unable to rollback after error: %s", err.Error()), + ) + } return err } return tx.Commit() default: - return fmt.Errorf("unexpected error on kissorm: db has an invalid type") + return fmt.Errorf("unexpected error on kissorm: db attribute has an invalid type") } } diff --git a/kiss_orm_test.go b/kiss_orm_test.go index a21d597..063490e 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -3,6 +3,7 @@ package kissorm import ( "context" "database/sql" + "errors" "fmt" "strings" "testing" @@ -795,6 +796,77 @@ func TestQueryChunks(t *testing.T) { } } +func TestTransaction(t *testing.T) { + for _, driver := range []string{"sqlite3", "postgres"} { + t.Run(driver, func(t *testing.T) { + t.Run("should query a single row correctly", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + + var users []User + err = c.Transaction(ctx, func(db ORMProvider) error { + db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") + return nil + }) + assert.Equal(t, nil, err) + + assert.Equal(t, 2, len(users)) + assert.Equal(t, "User1", users[0].Name) + assert.Equal(t, "User2", users[1].Name) + }) + + t.Run("should rollback when there are errors", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + u1 := User{Name: "User1", Age: 42} + u2 := User{Name: "User2", Age: 42} + _ = c.Insert(ctx, &u1) + _ = c.Insert(ctx, &u2) + + err = c.Transaction(ctx, func(db ORMProvider) error { + fmt.Printf("received db client: %#v\n", db) + err = db.Insert(ctx, &User{Name: "User3"}) + assert.Equal(t, nil, err) + err = db.Insert(ctx, &User{Name: "User4"}) + assert.Equal(t, nil, err) + err = db.Exec(ctx, "UPDATE users SET age = 22") + assert.Equal(t, nil, err) + + return errors.New("fake-error") + }) + assert.NotEqual(t, nil, err) + assert.Equal(t, "fake-error", err.Error()) + + var users []User + err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") + assert.Equal(t, nil, err) + + assert.Equal(t, []User{u1, u2}, users) + }) + }) + } +} + func TestScanRows(t *testing.T) { t.Run("should scan users correctly", func(t *testing.T) { err := createTable("sqlite3") diff --git a/mocks.go b/mocks.go index 9ab1688..664c56b 100644 --- a/mocks.go +++ b/mocks.go @@ -2,6 +2,8 @@ package kissorm import "context" +var _ ORMProvider = MockORMProvider{} + type MockORMProvider struct { InsertFn func(ctx context.Context, records ...interface{}) error DeleteFn func(ctx context.Context, ids ...interface{}) error @@ -11,7 +13,8 @@ type MockORMProvider struct { QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error QueryChunksFn func(ctx context.Context, parser ChunkParser) error - ExecFn func(ctx context.Context, query string, params ...interface{}) error + ExecFn func(ctx context.Context, query string, params ...interface{}) error + TransactionFn func(ctx context.Context, fn func(db ORMProvider) error) error } func (m MockORMProvider) Insert(ctx context.Context, records ...interface{}) error { @@ -41,3 +44,7 @@ func (m MockORMProvider) QueryChunks(ctx context.Context, parser ChunkParser) er func (m MockORMProvider) Exec(ctx context.Context, query string, params ...interface{}) error { return m.ExecFn(ctx, query, params...) } + +func (m MockORMProvider) Transaction(ctx context.Context, fn func(db ORMProvider) error) error { + return m.TransactionFn(ctx, fn) +}