diff --git a/test_adapters.go b/test_adapters.go
index 01581a3..2362398 100644
--- a/test_adapters.go
+++ b/test_adapters.go
@@ -2350,6 +2350,37 @@ func TransactionTest(
 			tt.AssertErrContains(t, err, "fakePanicPayload", "fakeRollbackErrMsg")
 		})
 
+		t.Run("should handle rollback errors when fn returns an error", func(t *testing.T) {
+			err := createTables(driver, connStr)
+			if err != nil {
+				t.Fatal("could not create test table!, reason:", err.Error())
+			}
+
+			db, closer := newDBAdapter(t)
+			defer closer.Close()
+
+			ctx := context.Background()
+			c := newTestDB(db, driver)
+
+			cMock := mockTxBeginner{
+				DBAdapter: c.db,
+				BeginTxFn: func(ctx context.Context) (Tx, error) {
+					return mockTx{
+						DBAdapter: c.db,
+						RollbackFn: func(ctx context.Context) error {
+							return fmt.Errorf("fakeRollbackErrMsg")
+						},
+					}, nil
+				},
+			}
+			c.db = cMock
+
+			err = c.Transaction(ctx, func(db Provider) error {
+				return fmt.Errorf("fakeTransactionErrMsg")
+			})
+			tt.AssertErrContains(t, err, "fakeTransactionErrMsg", "fakeRollbackErrMsg")
+		})
+
 		t.Run("should report error when BeginTx() fails", func(t *testing.T) {
 			db, closer := newDBAdapter(t)
 			defer closer.Close()