From b5f2deac0271c4c258f1895c0c465a76d8f7ba88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Wed, 3 Aug 2022 20:11:05 -0300 Subject: [PATCH] Add a few more tests to Transaction --- ksql.go | 8 +++--- test_adapters.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/ksql.go b/ksql.go index e6072d5..e49e93b 100644 --- a/ksql.go +++ b/ksql.go @@ -877,14 +877,14 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { case TxBeginner: tx, err := txBeginner.BeginTx(ctx) if err != nil { - return err + return fmt.Errorf("KSQL: error starting transaction: %s", err) } defer func() { if r := recover(); r != nil { rollbackErr := tx.Rollback(ctx) if rollbackErr != nil { r = errors.Wrap(rollbackErr, - fmt.Sprintf("unable to rollback after panic with value: %v", r), + fmt.Sprintf("KSQL: unable to rollback after panic with value: %v", r), ) } panic(r) @@ -899,7 +899,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { rollbackErr := tx.Rollback(ctx) if rollbackErr != nil { err = errors.Wrap(rollbackErr, - fmt.Sprintf("unable to rollback after error: %s", err.Error()), + fmt.Sprintf("KSQL: unable to rollback after error: %s", err.Error()), ) } return err @@ -908,7 +908,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { return tx.Commit(ctx) default: - return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBeginner interface") + return fmt.Errorf("KSQL: can't start transaction: The DBAdapter doesn't implement the TxBeginner interface") } } diff --git a/test_adapters.go b/test_adapters.go index 9db6dd5..e79c837 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -2167,6 +2167,15 @@ func QueryChunksTest( }) } +type MockTxBeginner struct { + DBAdapter + BeginTxFn func(ctx context.Context) (Tx, error) +} + +func (b MockTxBeginner) BeginTx(ctx context.Context) (Tx, error) { + return b.BeginTxFn(ctx) +} + // TransactionTest runs all tests for making sure the Transaction function is // working for a given adapter and driver. func TransactionTest( @@ -2203,6 +2212,43 @@ func TransactionTest( tt.AssertEqual(t, users[1].Name, "User2") }) + t.Run("should work normally in nested transactions", func(t *testing.T) { + err := createTables(driver, connStr) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + u := user{ + Name: "User1", + } + err = c.Insert(ctx, usersTable, &u) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u.ID, 0) + + var updatedUser user + err = c.Transaction(ctx, func(db Provider) error { + u.Age = 42 + err = db.Patch(ctx, usersTable, &u) + if err != nil { + return err + } + + return db.Transaction(ctx, func(db Provider) error { + return db.QueryOne(ctx, &updatedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID) + }) + }) + tt.AssertNoErr(t, err) + + tt.AssertEqual(t, updatedUser.ID, u.ID) + tt.AssertEqual(t, updatedUser.Age, 42) + }) + t.Run("should rollback when there are errors", func(t *testing.T) { err := createTables(driver, connStr) if err != nil { @@ -2239,6 +2285,27 @@ func TransactionTest( tt.AssertEqual(t, users, []user{u1, u2}) }) + + t.Run("should report error when BeginTx() fails", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + cMock := MockTxBeginner{ + DBAdapter: c.db, + BeginTxFn: func(ctx context.Context) (Tx, error) { + return nil, fmt.Errorf("fakeErrMsg") + }, + } + c.db = cMock + + err := c.Transaction(ctx, func(db Provider) error { + return nil + }) + tt.AssertErrContains(t, err, "KSQL", "fakeErrMsg") + }) }) }