mirror of https://github.com/VinGarcia/ksql.git
Add tests for the Transaction function
parent
479e47b018
commit
5d083e35f0
|
@ -22,7 +22,7 @@ type ORMProvider interface {
|
|||
QueryChunks(ctx context.Context, parser ChunkParser) error
|
||||
|
||||
Exec(ctx context.Context, query string, params ...interface{}) error
|
||||
Transaction(ctx context.Context, fn func(ORMProvider) error) (err error)
|
||||
Transaction(ctx context.Context, fn func(ORMProvider) error) error
|
||||
}
|
||||
|
||||
// ChunkParser stores the arguments of the QueryChunks function
|
||||
|
|
|
@ -144,6 +144,20 @@ func (mr *MockORMProviderMockRecorder) QueryOne(ctx, record, query interface{},
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockORMProvider)(nil).QueryOne), varargs...)
|
||||
}
|
||||
|
||||
// Transaction mocks base method.
|
||||
func (m *MockORMProvider) Transaction(ctx context.Context, fn func(kissorm.ORMProvider) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Transaction", ctx, fn)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Transaction indicates an expected call of Transaction.
|
||||
func (mr *MockORMProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockORMProvider)(nil).Transaction), ctx, fn)
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockORMProvider) Update(ctx context.Context, records ...interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
23
kiss_orm.go
23
kiss_orm.go
|
@ -6,6 +6,8 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DB ...
|
||||
|
@ -504,19 +506,23 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error
|
|||
}
|
||||
|
||||
// Transaction just runs an SQL command on the database returning no rows.
|
||||
func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) {
|
||||
func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) error {
|
||||
switch db := c.db.(type) {
|
||||
case *sql.Tx:
|
||||
return fn(c)
|
||||
case *sql.DB:
|
||||
var tx *sql.Tx
|
||||
tx, err = db.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
_ = tx.Rollback()
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
r = errors.Wrap(rollbackErr,
|
||||
fmt.Sprintf("unable to rollback after panic with value: %v", r),
|
||||
)
|
||||
}
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
@ -526,14 +532,19 @@ func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err er
|
|||
|
||||
err = fn(ormCopy)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
err = errors.Wrap(rollbackErr,
|
||||
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected error on kissorm: db has an invalid type")
|
||||
return fmt.Errorf("unexpected error on kissorm: db attribute has an invalid type")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package kissorm
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -795,6 +796,77 @@ func TestQueryChunks(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
||||
t.Run(driver, func(t *testing.T) {
|
||||
t.Run("should query a single row correctly", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, driver, "users")
|
||||
|
||||
_ = c.Insert(ctx, &User{Name: "User1"})
|
||||
_ = c.Insert(ctx, &User{Name: "User2"})
|
||||
|
||||
var users []User
|
||||
err = c.Transaction(ctx, func(db ORMProvider) error {
|
||||
db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
||||
return nil
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.Equal(t, 2, len(users))
|
||||
assert.Equal(t, "User1", users[0].Name)
|
||||
assert.Equal(t, "User2", users[1].Name)
|
||||
})
|
||||
|
||||
t.Run("should rollback when there are errors", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, driver, "users")
|
||||
|
||||
u1 := User{Name: "User1", Age: 42}
|
||||
u2 := User{Name: "User2", Age: 42}
|
||||
_ = c.Insert(ctx, &u1)
|
||||
_ = c.Insert(ctx, &u2)
|
||||
|
||||
err = c.Transaction(ctx, func(db ORMProvider) error {
|
||||
fmt.Printf("received db client: %#v\n", db)
|
||||
err = db.Insert(ctx, &User{Name: "User3"})
|
||||
assert.Equal(t, nil, err)
|
||||
err = db.Insert(ctx, &User{Name: "User4"})
|
||||
assert.Equal(t, nil, err)
|
||||
err = db.Exec(ctx, "UPDATE users SET age = 22")
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
return errors.New("fake-error")
|
||||
})
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "fake-error", err.Error())
|
||||
|
||||
var users []User
|
||||
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.Equal(t, []User{u1, u2}, users)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRows(t *testing.T) {
|
||||
t.Run("should scan users correctly", func(t *testing.T) {
|
||||
err := createTable("sqlite3")
|
||||
|
|
9
mocks.go
9
mocks.go
|
@ -2,6 +2,8 @@ package kissorm
|
|||
|
||||
import "context"
|
||||
|
||||
var _ ORMProvider = MockORMProvider{}
|
||||
|
||||
type MockORMProvider struct {
|
||||
InsertFn func(ctx context.Context, records ...interface{}) error
|
||||
DeleteFn func(ctx context.Context, ids ...interface{}) error
|
||||
|
@ -11,7 +13,8 @@ type MockORMProvider struct {
|
|||
QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error
|
||||
QueryChunksFn func(ctx context.Context, parser ChunkParser) error
|
||||
|
||||
ExecFn func(ctx context.Context, query string, params ...interface{}) error
|
||||
ExecFn func(ctx context.Context, query string, params ...interface{}) error
|
||||
TransactionFn func(ctx context.Context, fn func(db ORMProvider) error) error
|
||||
}
|
||||
|
||||
func (m MockORMProvider) Insert(ctx context.Context, records ...interface{}) error {
|
||||
|
@ -41,3 +44,7 @@ func (m MockORMProvider) QueryChunks(ctx context.Context, parser ChunkParser) er
|
|||
func (m MockORMProvider) Exec(ctx context.Context, query string, params ...interface{}) error {
|
||||
return m.ExecFn(ctx, query, params...)
|
||||
}
|
||||
|
||||
func (m MockORMProvider) Transaction(ctx context.Context, fn func(db ORMProvider) error) error {
|
||||
return m.TransactionFn(ctx, fn)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue