From cea28ace2b17c5f98f4357fdbc52d20c2c603e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 19 Feb 2022 10:56:47 -0300 Subject: [PATCH] Refactor TestTransaction() so its decoupled from the adapters --- ksql_test.go | 144 ++++++++++++++++++++++++++++----------------------- 1 file changed, 79 insertions(+), 65 deletions(-) diff --git a/ksql_test.go b/ksql_test.go index 3cfd4ed..6ae62b6 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -2069,74 +2069,88 @@ func QueryChunksTest( func TestTransaction(t *testing.T) { for _, config := range supportedConfigs { - t.Run(config.driver, func(t *testing.T) { - t.Run("should query a single row correctly", func(t *testing.T) { - err := createTables(config.driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) - _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) - - var users []User - err = c.Transaction(ctx, func(db Provider) error { - db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") - return nil - }) - assert.Equal(t, nil, err) - - assert.Equal(t, 2, len(users)) - assert.Equal(t, "User1", users[0].Name) - assert.Equal(t, "User2", users[1].Name) - }) - - t.Run("should rollback when there are errors", func(t *testing.T) { - err := createTables(config.driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - u1 := User{Name: "User1", Age: 42} - u2 := User{Name: "User2", Age: 42} - _ = c.Insert(ctx, UsersTable, &u1) - _ = c.Insert(ctx, UsersTable, &u2) - - err = c.Transaction(ctx, func(db Provider) error { - err = db.Insert(ctx, UsersTable, &User{Name: "User3"}) - assert.Equal(t, nil, err) - err = db.Insert(ctx, UsersTable, &User{Name: "User4"}) - assert.Equal(t, nil, err) - _, err = db.Exec(ctx, "UPDATE users SET age = 22") - assert.Equal(t, nil, err) - - return errors.New("fake-error") - }) - assert.NotEqual(t, nil, err) - assert.Equal(t, "fake-error", err.Error()) - - var users []User - err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") - assert.Equal(t, nil, err) - - assert.Equal(t, []User{u1, u2}, users) - }) - }) + TransactionTest(t, + config, + func(t *testing.T) (DBAdapter, io.Closer) { + db, close := connectDB(t, config) + return db, close + }, + ) } } +func TransactionTest( + t *testing.T, + config testConfig, + newDBAdapter func(t *testing.T) (DBAdapter, io.Closer), +) { + t.Run(config.driver, func(t *testing.T) { + t.Run("should query a single row correctly", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + + var users []User + err = c.Transaction(ctx, func(db Provider) error { + db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") + return nil + }) + assert.Equal(t, nil, err) + + assert.Equal(t, 2, len(users)) + assert.Equal(t, "User1", users[0].Name) + assert.Equal(t, "User2", users[1].Name) + }) + + t.Run("should rollback when there are errors", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + u1 := User{Name: "User1", Age: 42} + u2 := User{Name: "User2", Age: 42} + _ = c.Insert(ctx, UsersTable, &u1) + _ = c.Insert(ctx, UsersTable, &u2) + + err = c.Transaction(ctx, func(db Provider) error { + err = db.Insert(ctx, UsersTable, &User{Name: "User3"}) + assert.Equal(t, nil, err) + err = db.Insert(ctx, UsersTable, &User{Name: "User4"}) + assert.Equal(t, nil, err) + _, err = db.Exec(ctx, "UPDATE users SET age = 22") + assert.Equal(t, nil, err) + + return errors.New("fake-error") + }) + assert.NotEqual(t, nil, err) + assert.Equal(t, "fake-error", err.Error()) + + var users []User + err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") + assert.Equal(t, nil, err) + + assert.Equal(t, []User{u1, u2}, users) + }) + }) +} + func TestScanRows(t *testing.T) { t.Run("should scan users correctly", func(t *testing.T) { err := createTables("sqlite3")