mirror of https://github.com/VinGarcia/ksql.git
Improve test coverate on .Transaction()
parent
b5f2deac02
commit
06b8855621
|
@ -0,0 +1,28 @@
|
||||||
|
package ksql
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// mockTxBeginner mocks the ksql.TxBeginner interface
|
||||||
|
type mockTxBeginner struct {
|
||||||
|
DBAdapter
|
||||||
|
BeginTxFn func(ctx context.Context) (Tx, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b mockTxBeginner) BeginTx(ctx context.Context) (Tx, error) {
|
||||||
|
return b.BeginTxFn(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockTx mocks the ksql.Tx interface
|
||||||
|
type mockTx struct {
|
||||||
|
DBAdapter
|
||||||
|
RollbackFn func(ctx context.Context) error
|
||||||
|
CommitFn func(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockTx) Rollback(ctx context.Context) error {
|
||||||
|
return m.RollbackFn(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockTx) Commit(ctx context.Context) error {
|
||||||
|
return m.CommitFn(ctx)
|
||||||
|
}
|
|
@ -2167,15 +2167,6 @@ func QueryChunksTest(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockTxBeginner struct {
|
|
||||||
DBAdapter
|
|
||||||
BeginTxFn func(ctx context.Context) (Tx, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b MockTxBeginner) BeginTx(ctx context.Context) (Tx, error) {
|
|
||||||
return b.BeginTxFn(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TransactionTest runs all tests for making sure the Transaction function is
|
// TransactionTest runs all tests for making sure the Transaction function is
|
||||||
// working for a given adapter and driver.
|
// working for a given adapter and driver.
|
||||||
func TransactionTest(
|
func TransactionTest(
|
||||||
|
@ -2286,6 +2277,79 @@ func TransactionTest(
|
||||||
tt.AssertEqual(t, users, []user{u1, u2})
|
tt.AssertEqual(t, users, []user{u1, u2})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should rollback when the fn call panics", func(t *testing.T) {
|
||||||
|
err := createTables(driver, connStr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create test table!, reason:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
db, closer := newDBAdapter(t)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver)
|
||||||
|
|
||||||
|
u1 := user{Name: "User1", Age: 42}
|
||||||
|
u2 := user{Name: "User2", Age: 42}
|
||||||
|
_ = c.Insert(ctx, usersTable, &u1)
|
||||||
|
_ = c.Insert(ctx, usersTable, &u2)
|
||||||
|
|
||||||
|
panicPayload := tt.PanicHandler(func() {
|
||||||
|
c.Transaction(ctx, func(db Provider) error {
|
||||||
|
err = db.Insert(ctx, usersTable, &user{Name: "User3"})
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
err = db.Insert(ctx, usersTable, &user{Name: "User4"})
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
_, err = db.Exec(ctx, "UPDATE users SET age = 22")
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
panic("fakePanicPayload")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
tt.AssertEqual(t, panicPayload, "fakePanicPayload")
|
||||||
|
|
||||||
|
var users []user
|
||||||
|
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
tt.AssertEqual(t, users, []user{u1, u2})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should handle rollback errors when the fn call panics", func(t *testing.T) {
|
||||||
|
err := createTables(driver, connStr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create test table!, reason:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
db, closer := newDBAdapter(t)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver)
|
||||||
|
|
||||||
|
cMock := mockTxBeginner{
|
||||||
|
DBAdapter: c.db,
|
||||||
|
BeginTxFn: func(ctx context.Context) (Tx, error) {
|
||||||
|
return mockTx{
|
||||||
|
DBAdapter: c.db,
|
||||||
|
RollbackFn: func(ctx context.Context) error {
|
||||||
|
return fmt.Errorf("fakeRollbackErrMsg")
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.db = cMock
|
||||||
|
|
||||||
|
panicPayload := tt.PanicHandler(func() {
|
||||||
|
c.Transaction(ctx, func(db Provider) error {
|
||||||
|
panic("fakePanicPayload")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
err, ok := panicPayload.(error)
|
||||||
|
tt.AssertEqual(t, ok, true)
|
||||||
|
tt.AssertErrContains(t, err, "fakePanicPayload", "fakeRollbackErrMsg")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("should report error when BeginTx() fails", func(t *testing.T) {
|
t.Run("should report error when BeginTx() fails", func(t *testing.T) {
|
||||||
db, closer := newDBAdapter(t)
|
db, closer := newDBAdapter(t)
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
|
@ -2293,7 +2357,7 @@ func TransactionTest(
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
c := newTestDB(db, driver)
|
c := newTestDB(db, driver)
|
||||||
|
|
||||||
cMock := MockTxBeginner{
|
cMock := mockTxBeginner{
|
||||||
DBAdapter: c.db,
|
DBAdapter: c.db,
|
||||||
BeginTxFn: func(ctx context.Context) (Tx, error) {
|
BeginTxFn: func(ctx context.Context) (Tx, error) {
|
||||||
return nil, fmt.Errorf("fakeErrMsg")
|
return nil, fmt.Errorf("fakeErrMsg")
|
||||||
|
|
Loading…
Reference in New Issue