mirror of https://github.com/VinGarcia/ksql.git
Add tests for the Transaction function
parent
479e47b018
commit
5d083e35f0
|
@ -22,7 +22,7 @@ type ORMProvider interface {
|
||||||
QueryChunks(ctx context.Context, parser ChunkParser) error
|
QueryChunks(ctx context.Context, parser ChunkParser) error
|
||||||
|
|
||||||
Exec(ctx context.Context, query string, params ...interface{}) error
|
Exec(ctx context.Context, query string, params ...interface{}) error
|
||||||
Transaction(ctx context.Context, fn func(ORMProvider) error) (err error)
|
Transaction(ctx context.Context, fn func(ORMProvider) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChunkParser stores the arguments of the QueryChunks function
|
// ChunkParser stores the arguments of the QueryChunks function
|
||||||
|
|
|
@ -144,6 +144,20 @@ func (mr *MockORMProviderMockRecorder) QueryOne(ctx, record, query interface{},
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockORMProvider)(nil).QueryOne), varargs...)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockORMProvider)(nil).QueryOne), varargs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transaction mocks base method.
|
||||||
|
func (m *MockORMProvider) Transaction(ctx context.Context, fn func(kissorm.ORMProvider) error) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Transaction", ctx, fn)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transaction indicates an expected call of Transaction.
|
||||||
|
func (mr *MockORMProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockORMProvider)(nil).Transaction), ctx, fn)
|
||||||
|
}
|
||||||
|
|
||||||
// Update mocks base method.
|
// Update mocks base method.
|
||||||
func (m *MockORMProvider) Update(ctx context.Context, records ...interface{}) error {
|
func (m *MockORMProvider) Update(ctx context.Context, records ...interface{}) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
23
kiss_orm.go
23
kiss_orm.go
|
@ -6,6 +6,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DB ...
|
// DB ...
|
||||||
|
@ -504,19 +506,23 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transaction just runs an SQL command on the database returning no rows.
|
// Transaction just runs an SQL command on the database returning no rows.
|
||||||
func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) {
|
func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) error {
|
||||||
switch db := c.db.(type) {
|
switch db := c.db.(type) {
|
||||||
case *sql.Tx:
|
case *sql.Tx:
|
||||||
return fn(c)
|
return fn(c)
|
||||||
case *sql.DB:
|
case *sql.DB:
|
||||||
var tx *sql.Tx
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
tx, err = db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
_ = tx.Rollback()
|
rollbackErr := tx.Rollback()
|
||||||
|
if rollbackErr != nil {
|
||||||
|
r = errors.Wrap(rollbackErr,
|
||||||
|
fmt.Sprintf("unable to rollback after panic with value: %v", r),
|
||||||
|
)
|
||||||
|
}
|
||||||
panic(r)
|
panic(r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -526,14 +532,19 @@ func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err er
|
||||||
|
|
||||||
err = fn(ormCopy)
|
err = fn(ormCopy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
rollbackErr := tx.Rollback()
|
||||||
|
if rollbackErr != nil {
|
||||||
|
err = errors.Wrap(rollbackErr,
|
||||||
|
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unexpected error on kissorm: db has an invalid type")
|
return fmt.Errorf("unexpected error on kissorm: db attribute has an invalid type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package kissorm
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -795,6 +796,77 @@ func TestQueryChunks(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransaction(t *testing.T) {
|
||||||
|
for _, driver := range []string{"sqlite3", "postgres"} {
|
||||||
|
t.Run(driver, func(t *testing.T) {
|
||||||
|
t.Run("should query a single row correctly", func(t *testing.T) {
|
||||||
|
err := createTable(driver)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create test table!, reason:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
db := connectDB(t, driver)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver, "users")
|
||||||
|
|
||||||
|
_ = c.Insert(ctx, &User{Name: "User1"})
|
||||||
|
_ = c.Insert(ctx, &User{Name: "User2"})
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
err = c.Transaction(ctx, func(db ORMProvider) error {
|
||||||
|
db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(users))
|
||||||
|
assert.Equal(t, "User1", users[0].Name)
|
||||||
|
assert.Equal(t, "User2", users[1].Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should rollback when there are errors", func(t *testing.T) {
|
||||||
|
err := createTable(driver)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create test table!, reason:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
db := connectDB(t, driver)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver, "users")
|
||||||
|
|
||||||
|
u1 := User{Name: "User1", Age: 42}
|
||||||
|
u2 := User{Name: "User2", Age: 42}
|
||||||
|
_ = c.Insert(ctx, &u1)
|
||||||
|
_ = c.Insert(ctx, &u2)
|
||||||
|
|
||||||
|
err = c.Transaction(ctx, func(db ORMProvider) error {
|
||||||
|
fmt.Printf("received db client: %#v\n", db)
|
||||||
|
err = db.Insert(ctx, &User{Name: "User3"})
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
err = db.Insert(ctx, &User{Name: "User4"})
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
err = db.Exec(ctx, "UPDATE users SET age = 22")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
return errors.New("fake-error")
|
||||||
|
})
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
assert.Equal(t, "fake-error", err.Error())
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
assert.Equal(t, []User{u1, u2}, users)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestScanRows(t *testing.T) {
|
func TestScanRows(t *testing.T) {
|
||||||
t.Run("should scan users correctly", func(t *testing.T) {
|
t.Run("should scan users correctly", func(t *testing.T) {
|
||||||
err := createTable("sqlite3")
|
err := createTable("sqlite3")
|
||||||
|
|
9
mocks.go
9
mocks.go
|
@ -2,6 +2,8 @@ package kissorm
|
||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
|
var _ ORMProvider = MockORMProvider{}
|
||||||
|
|
||||||
type MockORMProvider struct {
|
type MockORMProvider struct {
|
||||||
InsertFn func(ctx context.Context, records ...interface{}) error
|
InsertFn func(ctx context.Context, records ...interface{}) error
|
||||||
DeleteFn func(ctx context.Context, ids ...interface{}) error
|
DeleteFn func(ctx context.Context, ids ...interface{}) error
|
||||||
|
@ -11,7 +13,8 @@ type MockORMProvider struct {
|
||||||
QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error
|
QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error
|
||||||
QueryChunksFn func(ctx context.Context, parser ChunkParser) error
|
QueryChunksFn func(ctx context.Context, parser ChunkParser) error
|
||||||
|
|
||||||
ExecFn func(ctx context.Context, query string, params ...interface{}) error
|
ExecFn func(ctx context.Context, query string, params ...interface{}) error
|
||||||
|
TransactionFn func(ctx context.Context, fn func(db ORMProvider) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m MockORMProvider) Insert(ctx context.Context, records ...interface{}) error {
|
func (m MockORMProvider) Insert(ctx context.Context, records ...interface{}) error {
|
||||||
|
@ -41,3 +44,7 @@ func (m MockORMProvider) QueryChunks(ctx context.Context, parser ChunkParser) er
|
||||||
func (m MockORMProvider) Exec(ctx context.Context, query string, params ...interface{}) error {
|
func (m MockORMProvider) Exec(ctx context.Context, query string, params ...interface{}) error {
|
||||||
return m.ExecFn(ctx, query, params...)
|
return m.ExecFn(ctx, query, params...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m MockORMProvider) Transaction(ctx context.Context, fn func(db ORMProvider) error) error {
|
||||||
|
return m.TransactionFn(ctx, fn)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue