From 45d8ef4491cc968df395a33d7a0ae5f16e022d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Wed, 3 Aug 2022 21:05:20 -0300 Subject: [PATCH] Finishes testing all error cases in the .Transaction() method --- test_adapters.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test_adapters.go b/test_adapters.go index 01581a3..2362398 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -2350,6 +2350,37 @@ func TransactionTest( tt.AssertErrContains(t, err, "fakePanicPayload", "fakeRollbackErrMsg") }) + t.Run("should handle rollback errors when fn returns an error", func(t *testing.T) { + err := createTables(driver, connStr) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + cMock := mockTxBeginner{ + DBAdapter: c.db, + BeginTxFn: func(ctx context.Context) (Tx, error) { + return mockTx{ + DBAdapter: c.db, + RollbackFn: func(ctx context.Context) error { + return fmt.Errorf("fakeRollbackErrMsg") + }, + }, nil + }, + } + c.db = cMock + + err = c.Transaction(ctx, func(db Provider) error { + return fmt.Errorf("fakeTransactionErrMsg") + }) + tt.AssertErrContains(t, err, "fakeTransactionErrMsg", "fakeRollbackErrMsg") + }) + t.Run("should report error when BeginTx() fails", func(t *testing.T) { db, closer := newDBAdapter(t) defer closer.Close()