mirror of https://github.com/VinGarcia/ksql.git
Add a few more tests to Transaction
parent
eb1f85f8bb
commit
b5f2deac02
8
ksql.go
8
ksql.go
|
@ -877,14 +877,14 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
||||||
case TxBeginner:
|
case TxBeginner:
|
||||||
tx, err := txBeginner.BeginTx(ctx)
|
tx, err := txBeginner.BeginTx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("KSQL: error starting transaction: %s", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
rollbackErr := tx.Rollback(ctx)
|
rollbackErr := tx.Rollback(ctx)
|
||||||
if rollbackErr != nil {
|
if rollbackErr != nil {
|
||||||
r = errors.Wrap(rollbackErr,
|
r = errors.Wrap(rollbackErr,
|
||||||
fmt.Sprintf("unable to rollback after panic with value: %v", r),
|
fmt.Sprintf("KSQL: unable to rollback after panic with value: %v", r),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
panic(r)
|
panic(r)
|
||||||
|
@ -899,7 +899,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
||||||
rollbackErr := tx.Rollback(ctx)
|
rollbackErr := tx.Rollback(ctx)
|
||||||
if rollbackErr != nil {
|
if rollbackErr != nil {
|
||||||
err = errors.Wrap(rollbackErr,
|
err = errors.Wrap(rollbackErr,
|
||||||
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
|
fmt.Sprintf("KSQL: unable to rollback after error: %s", err.Error()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -908,7 +908,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
||||||
return tx.Commit(ctx)
|
return tx.Commit(ctx)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBeginner interface")
|
return fmt.Errorf("KSQL: can't start transaction: The DBAdapter doesn't implement the TxBeginner interface")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2167,6 +2167,15 @@ func QueryChunksTest(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MockTxBeginner struct {
|
||||||
|
DBAdapter
|
||||||
|
BeginTxFn func(ctx context.Context) (Tx, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b MockTxBeginner) BeginTx(ctx context.Context) (Tx, error) {
|
||||||
|
return b.BeginTxFn(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// TransactionTest runs all tests for making sure the Transaction function is
|
// TransactionTest runs all tests for making sure the Transaction function is
|
||||||
// working for a given adapter and driver.
|
// working for a given adapter and driver.
|
||||||
func TransactionTest(
|
func TransactionTest(
|
||||||
|
@ -2203,6 +2212,43 @@ func TransactionTest(
|
||||||
tt.AssertEqual(t, users[1].Name, "User2")
|
tt.AssertEqual(t, users[1].Name, "User2")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should work normally in nested transactions", func(t *testing.T) {
|
||||||
|
err := createTables(driver, connStr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create test table!, reason:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
db, closer := newDBAdapter(t)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver)
|
||||||
|
|
||||||
|
u := user{
|
||||||
|
Name: "User1",
|
||||||
|
}
|
||||||
|
err = c.Insert(ctx, usersTable, &u)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
tt.AssertNotEqual(t, u.ID, 0)
|
||||||
|
|
||||||
|
var updatedUser user
|
||||||
|
err = c.Transaction(ctx, func(db Provider) error {
|
||||||
|
u.Age = 42
|
||||||
|
err = db.Patch(ctx, usersTable, &u)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Transaction(ctx, func(db Provider) error {
|
||||||
|
return db.QueryOne(ctx, &updatedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
tt.AssertEqual(t, updatedUser.ID, u.ID)
|
||||||
|
tt.AssertEqual(t, updatedUser.Age, 42)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("should rollback when there are errors", func(t *testing.T) {
|
t.Run("should rollback when there are errors", func(t *testing.T) {
|
||||||
err := createTables(driver, connStr)
|
err := createTables(driver, connStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2239,6 +2285,27 @@ func TransactionTest(
|
||||||
|
|
||||||
tt.AssertEqual(t, users, []user{u1, u2})
|
tt.AssertEqual(t, users, []user{u1, u2})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should report error when BeginTx() fails", func(t *testing.T) {
|
||||||
|
db, closer := newDBAdapter(t)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver)
|
||||||
|
|
||||||
|
cMock := MockTxBeginner{
|
||||||
|
DBAdapter: c.db,
|
||||||
|
BeginTxFn: func(ctx context.Context) (Tx, error) {
|
||||||
|
return nil, fmt.Errorf("fakeErrMsg")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.db = cMock
|
||||||
|
|
||||||
|
err := c.Transaction(ctx, func(db Provider) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
tt.AssertErrContains(t, err, "KSQL", "fakeErrMsg")
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue