diff --git a/ksql_test.go b/ksql_test.go index 081eb2c..7b63c7b 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -1346,185 +1346,199 @@ func DeleteTest( func TestUpdate(t *testing.T) { for _, config := range supportedConfigs { - t.Run(config.driver, func(t *testing.T) { - err := createTables(config.driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) + UpdateTest(t, + config, + func(t *testing.T) (DBAdapter, io.Closer) { + db, close := connectDB(t, config) + return db, close + }, + ) + } +} + +func UpdateTest( + t *testing.T, + config testConfig, + newDBAdapter func(t *testing.T) (DBAdapter, io.Closer), +) { + t.Run(config.driver, func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + t.Run("should update one User{} correctly", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + u := User{ + Name: "Letícia", + } + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) + assert.Equal(t, nil, err) + + err = getUserByName(db, config.driver, &u, "Letícia") + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + + err = c.Update(ctx, UsersTable, User{ + ID: u.ID, + Name: "Thayane", + }) + assert.Equal(t, nil, err) + + var result User + err = getUserByID(c.db, c.dialect, &result, u.ID) + assert.Equal(t, nil, err) + assert.Equal(t, "Thayane", result.Name) + }) + + t.Run("should update one &User{} correctly", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + u := User{ + Name: "Letícia", + } + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) + assert.Equal(t, nil, err) + + err = getUserByName(db, config.driver, &u, "Letícia") + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + + err = c.Update(ctx, UsersTable, &User{ + ID: u.ID, + Name: "Thayane", + }) + assert.Equal(t, nil, err) + + var result User + err = getUserByID(c.db, c.dialect, &result, u.ID) + assert.Equal(t, nil, err) + assert.Equal(t, "Thayane", result.Name) + }) + + t.Run("should ignore null pointers on partial updates", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + type partialUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + Age *int `ksql:"age"` } - t.Run("should update one User{} correctly", func(t *testing.T) { - db, closer := connectDB(t, config) - defer closer.Close() + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) + assert.Equal(t, nil, err) - ctx := context.Background() - c := newTestDB(db, config.driver) + var u User + err = getUserByName(db, config.driver, &u, "Letícia") + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) - u := User{ - Name: "Letícia", - } - _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) - assert.Equal(t, nil, err) - - err = getUserByName(db, config.driver, &u, "Letícia") - assert.Equal(t, nil, err) - assert.NotEqual(t, uint(0), u.ID) - - err = c.Update(ctx, UsersTable, User{ - ID: u.ID, - Name: "Thayane", - }) - assert.Equal(t, nil, err) - - var result User - err = getUserByID(c.db, c.dialect, &result, u.ID) - assert.Equal(t, nil, err) - assert.Equal(t, "Thayane", result.Name) + err = c.Update(ctx, UsersTable, partialUser{ + ID: u.ID, + // Should be updated because it is not null, just empty: + Name: "", + // Should not be updated because it is null: + Age: nil, }) + assert.Equal(t, nil, err) - t.Run("should update one &User{} correctly", func(t *testing.T) { - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - u := User{ - Name: "Letícia", - } - _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) - assert.Equal(t, nil, err) - - err = getUserByName(db, config.driver, &u, "Letícia") - assert.Equal(t, nil, err) - assert.NotEqual(t, uint(0), u.ID) - - err = c.Update(ctx, UsersTable, &User{ - ID: u.ID, - Name: "Thayane", - }) - assert.Equal(t, nil, err) - - var result User - err = getUserByID(c.db, c.dialect, &result, u.ID) - assert.Equal(t, nil, err) - assert.Equal(t, "Thayane", result.Name) - }) - - t.Run("should ignore null pointers on partial updates", func(t *testing.T) { - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - type partialUser struct { - ID uint `ksql:"id"` - Name string `ksql:"name"` - Age *int `ksql:"age"` - } - - _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) - assert.Equal(t, nil, err) - - var u User - err = getUserByName(db, config.driver, &u, "Letícia") - assert.Equal(t, nil, err) - assert.NotEqual(t, uint(0), u.ID) - - err = c.Update(ctx, UsersTable, partialUser{ - ID: u.ID, - // Should be updated because it is not null, just empty: - Name: "", - // Should not be updated because it is null: - Age: nil, - }) - assert.Equal(t, nil, err) - - var result User - err = getUserByID(c.db, c.dialect, &result, u.ID) - assert.Equal(t, nil, err) - assert.Equal(t, "", result.Name) - assert.Equal(t, 22, result.Age) - }) - - t.Run("should update valid pointers on partial updates", func(t *testing.T) { - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - type partialUser struct { - ID uint `ksql:"id"` - Name string `ksql:"name"` - Age *int `ksql:"age"` - } - - _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) - assert.Equal(t, nil, err) - - var u User - err = getUserByName(db, config.driver, &u, "Letícia") - assert.Equal(t, nil, err) - assert.NotEqual(t, uint(0), u.ID) - - // Should update all fields: - err = c.Update(ctx, UsersTable, partialUser{ - ID: u.ID, - Name: "Thay", - Age: nullable.Int(42), - }) - assert.Equal(t, nil, err) - - var result User - err = getUserByID(c.db, c.dialect, &result, u.ID) - assert.Equal(t, nil, err) - - assert.Equal(t, "Thay", result.Name) - assert.Equal(t, 42, result.Age) - }) - - t.Run("should return ErrRecordNotFound when asked to update an inexistent user", func(t *testing.T) { - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - err = c.Update(ctx, UsersTable, User{ - ID: 4200, - Name: "Thayane", - }) - assert.Equal(t, ErrRecordNotFound, err) - }) - - t.Run("should report database errors correctly", func(t *testing.T) { - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - err = c.Update(ctx, NewTable("non_existing_table"), User{ - ID: 1, - Name: "Thayane", - }) - assert.NotEqual(t, nil, err) - }) - - t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) { - db, closer := connectDB(t, config) - defer closer.Close() - - ctx := context.Background() - c := newTestDB(db, config.driver) - - var user *User - err := c.Update(ctx, UsersTable, user) - assert.NotEqual(t, nil, err) - }) + var result User + err = getUserByID(c.db, c.dialect, &result, u.ID) + assert.Equal(t, nil, err) + assert.Equal(t, "", result.Name) + assert.Equal(t, 22, result.Age) }) - } + + t.Run("should update valid pointers on partial updates", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + type partialUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + Age *int `ksql:"age"` + } + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) + assert.Equal(t, nil, err) + + var u User + err = getUserByName(db, config.driver, &u, "Letícia") + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + + // Should update all fields: + err = c.Update(ctx, UsersTable, partialUser{ + ID: u.ID, + Name: "Thay", + Age: nullable.Int(42), + }) + assert.Equal(t, nil, err) + + var result User + err = getUserByID(c.db, c.dialect, &result, u.ID) + assert.Equal(t, nil, err) + + assert.Equal(t, "Thay", result.Name) + assert.Equal(t, 42, result.Age) + }) + + t.Run("should return ErrRecordNotFound when asked to update an inexistent user", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err = c.Update(ctx, UsersTable, User{ + ID: 4200, + Name: "Thayane", + }) + assert.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("should report database errors correctly", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err = c.Update(ctx, NewTable("non_existing_table"), User{ + ID: 1, + Name: "Thayane", + }) + assert.NotEqual(t, nil, err) + }) + + t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + var user *User + err := c.Update(ctx, UsersTable, user) + assert.NotEqual(t, nil, err) + }) + }) } func TestQueryChunks(t *testing.T) { @@ -2404,9 +2418,14 @@ func newTestDB(db DBAdapter, driver string) DB { } } -type NopCloser struct{} +type CloserAdapter struct { + close func() +} -func (NopCloser) Close() error { return nil } +func (c CloserAdapter) Close() error { + c.close() + return nil +} func connectDB(t *testing.T, config testConfig) (DBAdapter, io.Closer) { connStr := connectionString[config.driver] @@ -2426,7 +2445,7 @@ func connectDB(t *testing.T, config testConfig) (DBAdapter, io.Closer) { if err != nil { t.Fatal(err.Error()) } - return PGXAdapter{pool}, NopCloser{} + return PGXAdapter{pool}, CloserAdapter{close: pool.Close} } t.Fatalf("unsupported adapter: %s", config.adapterName)