Add tests for the Transaction function

pull/2/head
Vinícius Garcia 2021-01-17 10:54:21 -03:00
parent 479e47b018
commit 5d083e35f0
5 changed files with 112 additions and 8 deletions

View File

@ -22,7 +22,7 @@ type ORMProvider interface {
QueryChunks(ctx context.Context, parser ChunkParser) error
Exec(ctx context.Context, query string, params ...interface{}) error
Transaction(ctx context.Context, fn func(ORMProvider) error) (err error)
Transaction(ctx context.Context, fn func(ORMProvider) error) error
}
// ChunkParser stores the arguments of the QueryChunks function

View File

@ -144,6 +144,20 @@ func (mr *MockORMProviderMockRecorder) QueryOne(ctx, record, query interface{},
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockORMProvider)(nil).QueryOne), varargs...)
}
// Transaction mocks base method.
func (m *MockORMProvider) Transaction(ctx context.Context, fn func(kissorm.ORMProvider) error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Transaction", ctx, fn)
ret0, _ := ret[0].(error)
return ret0
}
// Transaction indicates an expected call of Transaction.
func (mr *MockORMProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockORMProvider)(nil).Transaction), ctx, fn)
}
// Update mocks base method.
func (m *MockORMProvider) Update(ctx context.Context, records ...interface{}) error {
m.ctrl.T.Helper()

View File

@ -6,6 +6,8 @@ import (
"fmt"
"reflect"
"strings"
"github.com/pkg/errors"
)
// DB ...
@ -504,19 +506,23 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error
}
// Transaction just runs an SQL command on the database returning no rows.
func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) {
func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) error {
switch db := c.db.(type) {
case *sql.Tx:
return fn(c)
case *sql.DB:
var tx *sql.Tx
tx, err = db.BeginTx(ctx, nil)
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
if r := recover(); r != nil {
_ = tx.Rollback()
rollbackErr := tx.Rollback()
if rollbackErr != nil {
r = errors.Wrap(rollbackErr,
fmt.Sprintf("unable to rollback after panic with value: %v", r),
)
}
panic(r)
}
}()
@ -526,14 +532,19 @@ func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err er
err = fn(ormCopy)
if err != nil {
_ = tx.Rollback()
rollbackErr := tx.Rollback()
if rollbackErr != nil {
err = errors.Wrap(rollbackErr,
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
)
}
return err
}
return tx.Commit()
default:
return fmt.Errorf("unexpected error on kissorm: db has an invalid type")
return fmt.Errorf("unexpected error on kissorm: db attribute has an invalid type")
}
}

View File

@ -3,6 +3,7 @@ package kissorm
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"testing"
@ -795,6 +796,77 @@ func TestQueryChunks(t *testing.T) {
}
}
func TestTransaction(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
t.Run(driver, func(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
var users []User
err = c.Transaction(ctx, func(db ORMProvider) error {
db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
return nil
})
assert.Equal(t, nil, err)
assert.Equal(t, 2, len(users))
assert.Equal(t, "User1", users[0].Name)
assert.Equal(t, "User2", users[1].Name)
})
t.Run("should rollback when there are errors", func(t *testing.T) {
err := createTable(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
u1 := User{Name: "User1", Age: 42}
u2 := User{Name: "User2", Age: 42}
_ = c.Insert(ctx, &u1)
_ = c.Insert(ctx, &u2)
err = c.Transaction(ctx, func(db ORMProvider) error {
fmt.Printf("received db client: %#v\n", db)
err = db.Insert(ctx, &User{Name: "User3"})
assert.Equal(t, nil, err)
err = db.Insert(ctx, &User{Name: "User4"})
assert.Equal(t, nil, err)
err = db.Exec(ctx, "UPDATE users SET age = 22")
assert.Equal(t, nil, err)
return errors.New("fake-error")
})
assert.NotEqual(t, nil, err)
assert.Equal(t, "fake-error", err.Error())
var users []User
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
assert.Equal(t, nil, err)
assert.Equal(t, []User{u1, u2}, users)
})
})
}
}
func TestScanRows(t *testing.T) {
t.Run("should scan users correctly", func(t *testing.T) {
err := createTable("sqlite3")

View File

@ -2,6 +2,8 @@ package kissorm
import "context"
var _ ORMProvider = MockORMProvider{}
type MockORMProvider struct {
InsertFn func(ctx context.Context, records ...interface{}) error
DeleteFn func(ctx context.Context, ids ...interface{}) error
@ -11,7 +13,8 @@ type MockORMProvider struct {
QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error
QueryChunksFn func(ctx context.Context, parser ChunkParser) error
ExecFn func(ctx context.Context, query string, params ...interface{}) error
ExecFn func(ctx context.Context, query string, params ...interface{}) error
TransactionFn func(ctx context.Context, fn func(db ORMProvider) error) error
}
func (m MockORMProvider) Insert(ctx context.Context, records ...interface{}) error {
@ -41,3 +44,7 @@ func (m MockORMProvider) QueryChunks(ctx context.Context, parser ChunkParser) er
func (m MockORMProvider) Exec(ctx context.Context, query string, params ...interface{}) error {
return m.ExecFn(ctx, query, params...)
}
func (m MockORMProvider) Transaction(ctx context.Context, fn func(db ORMProvider) error) error {
return m.TransactionFn(ctx, fn)
}