package kisssql import ( "context" "database/sql" "encoding/json" "errors" "fmt" "strings" "testing" "github.com/ditointernet/go-assert" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/vingarcia/kisssql/nullable" ) type User struct { ID uint `kisssql:"id"` Name string `kisssql:"name"` Age int `kisssql:"age"` Address Address `kisssql:"address,json"` } type Address struct { Street string `json:"street"` Number string `json:"number"` City string `json:"city"` State string `json:"state"` Country string `json:"country"` } func TestQuery(t *testing.T) { for _, driver := range []string{"sqlite3", "postgres"} { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should return 0 results correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") var users []User err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) assert.Equal(t, nil, err) assert.Equal(t, []User(nil), users) users = []User{} err = c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) assert.Equal(t, nil, err) assert.Equal(t, []User{}, users) }) t.Run("should return a user correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") var users []User err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") assert.Equal(t, nil, err) assert.Equal(t, 1, len(users)) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "Bia", users[0].Name) assert.Equal(t, "BR", users[0].Address.Country) }) t.Run("should return multiple users correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() _, err := 
db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) assert.Equal(t, nil, err) _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") var users []User err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") assert.Equal(t, nil, err) assert.Equal(t, 2, len(users)) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "João Garcia", users[0].Name) assert.Equal(t, "US", users[0].Address.Country) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "Bia Garcia", users[1].Name) assert.Equal(t, "BR", users[1].Address.Country) }) }) t.Run("using slice of pointers to structs", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should return 0 results correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") var users []*User err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) assert.Equal(t, nil, err) assert.Equal(t, []*User(nil), users) users = []*User{} err = c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) assert.Equal(t, nil, err) assert.Equal(t, []*User{}, users) }) t.Run("should return a user correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") var users []*User err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") assert.Equal(t, nil, err) assert.Equal(t, 1, len(users)) assert.Equal(t, "Bia", users[0].Name) assert.NotEqual(t, uint(0), users[0].ID) }) t.Run("should return multiple users correctly", func(t 
*testing.T) { db := connectDB(t, driver) defer db.Close() _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) assert.Equal(t, nil, err) _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") var users []*User err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") assert.Equal(t, nil, err) assert.Equal(t, 2, len(users)) assert.Equal(t, "João Garcia", users[0].Name) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "US", users[0].Address.Country) assert.Equal(t, "Bia Garcia", users[1].Name) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "BR", users[1].Address.Country) }) }) t.Run("testing error cases", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should report error if input is not a pointer to a slice of structs", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Andréa Sá', 0)`) assert.Equal(t, nil, err) _, err = db.Exec(`INSERT INTO users (name, age) VALUES ('Caio Sá', 0)`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") err = c.Query(ctx, &User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) err = c.Query(ctx, []*User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) var i int err = c.Query(ctx, &i, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) err = c.Query(ctx, &[]int{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) }) t.Run("should report error if the query is not valid", func(t 
*testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") var users []User err = c.Query(ctx, &users, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) }) }) } } func TestQueryOne(t *testing.T) { for _, driver := range []string{"sqlite3", "postgres"} { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u := User{} err := c.QueryOne(ctx, &u, `SELECT * FROM users WHERE id=1;`) assert.Equal(t, ErrRecordNotFound, err) }) t.Run("should return a user correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") u := User{} err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) assert.Equal(t, "Bia", u.Name) assert.Equal(t, Address{ Country: "BR", }, u.Address) }) t.Run("should return only the first result on multiples matches", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) assert.Equal(t, nil, err) _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") var u User err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") assert.Equal(t, nil, err) assert.Equal(t, "Andréa Sá", u.Name) assert.Equal(t, 0, 
u.Age) assert.Equal(t, Address{ Country: "US", }, u.Address) }) t.Run("should report error if input is not a pointer to struct", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) assert.Equal(t, nil, err) _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() c := newTestDB(db, driver, "users") err = c.QueryOne(ctx, &[]User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) err = c.QueryOne(ctx, User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) }) t.Run("should report error if the query is not valid", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") var user User err = c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) }) } } func TestInsert(t *testing.T) { for _, driver := range []string{"sqlite3", "postgres"} { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should ignore empty lists of users", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") err = c.Insert(ctx) assert.Equal(t, nil, err) }) t.Run("should insert one user correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u := User{ Name: "Fernanda", Address: Address{ Country: "Brazil", }, } err := c.Insert(ctx, &u) assert.Equal(t, nil, err) assert.NotEqual(t, 0, u.ID) result := User{} err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, 
nil, err) assert.Equal(t, u.Name, result.Name) assert.Equal(t, u.Address, result.Address) }) t.Run("should insert ignoring the ID for sqlite and multiple ids", func(t *testing.T) { if driver != "sqlite3" { return } db := connectDB(t, driver) defer db.Close() ctx := context.Background() // Using columns "id" and "name" as IDs: c, err := New(driver, connectionString[driver], Config{ TableName: "users", IDColumns: []string{"id", "name"}, }) assert.Equal(t, nil, err) u := User{ Name: "No ID returned", Age: 3434, // Random number to avoid false positives on this test Address: Address{ Country: "Brazil 3434", }, } err = c.Insert(ctx, &u) assert.Equal(t, nil, err) assert.Equal(t, uint(0), u.ID) result := User{} err = getUserByName(c.db, c.dialect, &result, "No ID returned") assert.Equal(t, nil, err) assert.Equal(t, u.Age, result.Age) assert.Equal(t, u.Address, result.Address) }) }) t.Run("testing error cases", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should report error for invalid input types", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") err = c.Insert(ctx, "foo", "bar") assert.NotEqual(t, nil, err) err = c.Insert(ctx, nullable.String("foo"), nullable.String("bar")) assert.NotEqual(t, nil, err) err = c.Insert(ctx, map[string]interface{}{ "name": "foo", "age": 12, }) assert.NotEqual(t, nil, err) ifUserForgetToExpandList := []interface{}{ &User{Name: "foo", Age: 22}, &User{Name: "bar", Age: 32}, } err = c.Insert(ctx, ifUserForgetToExpandList) assert.NotEqual(t, nil, err) // We might want to support this in the future, but not for now: err = c.Insert(ctx, User{Name: "not a ptr to user", Age: 42}) assert.NotEqual(t, nil, err) }) t.Run("should report error if for some reason the insertMethod is invalid", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := 
newTestDB(db, driver, "users") // This is an invalid value: c.insertMethod = insertMethod(42) err = c.Insert(ctx, &User{Name: "foo"}) assert.NotEqual(t, nil, err) }) }) }) } } func TestDelete(t *testing.T) { for _, driver := range []string{"sqlite3", "postgres"} { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should ignore empty lists of ids", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u := User{ Name: "Won't be deleted", } err := c.Insert(ctx, &u) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) result := User{} err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, nil, err) assert.Equal(t, u.ID, result.ID) err = c.Delete(ctx) assert.Equal(t, nil, err) result = User{} err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, nil, err) assert.Equal(t, u.ID, result.ID) }) t.Run("should delete one id correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u1 := User{ Name: "Fernanda", } err := c.Insert(ctx, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) result := User{} err = getUserByID(c.db, c.dialect, &result, u1.ID) assert.Equal(t, nil, err) assert.Equal(t, u1.ID, result.ID) u2 := User{ Name: "Won't be deleted", } err = c.Insert(ctx, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) result = User{} err = getUserByID(c.db, c.dialect, &result, u2.ID) assert.Equal(t, nil, err) assert.Equal(t, u2.ID, result.ID) err = c.Delete(ctx, u1.ID) assert.Equal(t, nil, err) result = User{} err = getUserByID(c.db, c.dialect, &result, u1.ID) assert.Equal(t, sql.ErrNoRows, err) result = User{} err = getUserByID(c.db, c.dialect, &result, u2.ID) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), result.ID) assert.Equal(t, "Won't be deleted", 
result.Name) }) t.Run("should delete multiple ids correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u1 := User{ Name: "Fernanda", } err := c.Insert(ctx, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) u2 := User{ Name: "Juliano", } err = c.Insert(ctx, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) u3 := User{ Name: "This won't be deleted", } err = c.Insert(ctx, &u3) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u3.ID) result := User{} err = getUserByID(c.db, c.dialect, &result, u1.ID) assert.Equal(t, nil, err) assert.Equal(t, u1.ID, result.ID) result = User{} err = getUserByID(c.db, c.dialect, &result, u2.ID) assert.Equal(t, nil, err) assert.Equal(t, u2.ID, result.ID) result = User{} err = getUserByID(c.db, c.dialect, &result, u3.ID) assert.Equal(t, nil, err) assert.Equal(t, u3.ID, result.ID) err = c.Delete(ctx, u1.ID, u2.ID) assert.Equal(t, nil, err) results := []User{} err = getUsersByID(c.db, c.dialect, &results, u1.ID, u2.ID, u3.ID) assert.Equal(t, nil, err) assert.Equal(t, 1, len(results)) assert.Equal(t, "This won't be deleted", results[0].Name) }) }) } } func TestUpdate(t *testing.T) { for _, driver := range []string{"sqlite3", "postgres"} { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should ignore empty lists of ids", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u := User{ Name: "Thay", } err := c.Insert(ctx, &u) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) // Empty update, should do nothing: err = c.Update(ctx) assert.Equal(t, nil, err) result := User{} err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, nil, err) assert.Equal(t, "Thay", result.Name) }) t.Run("should update one user correctly", func(t 
*testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u := User{ Name: "Letícia", } _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 0)`) assert.Equal(t, nil, err) row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) assert.Equal(t, nil, row.Err()) err = row.Scan(&u.ID) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) err = c.Update(ctx, User{ ID: u.ID, Name: "Thayane", }) assert.Equal(t, nil, err) var result User err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, nil, err) assert.Equal(t, "Thayane", result.Name) }) t.Run("should update one user correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") u := User{ Name: "Letícia", } _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 0)`) assert.Equal(t, nil, err) row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) assert.Equal(t, nil, row.Err()) err = row.Scan(&u.ID) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) err = c.Update(ctx, User{ ID: u.ID, Name: "Thayane", }) assert.Equal(t, nil, err) var result User err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, nil, err) assert.Equal(t, "Thayane", result.Name) }) t.Run("should ignore null pointers on partial updates", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") type partialUser struct { ID uint `kisssql:"id"` Name string `kisssql:"name"` Age *int `kisssql:"age"` } u := partialUser{ Name: "Letícia", Age: nullable.Int(22), } _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 22)`) assert.Equal(t, nil, err) row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) assert.Equal(t, nil, row.Err()) err = row.Scan(&u.ID) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) err = c.Update(ctx, 
partialUser{ ID: u.ID, // Should be updated because it is not null, just empty: Name: "", // Should not be updated because it is null: Age: nil, }) assert.Equal(t, nil, err) var result User err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, nil, err) assert.Equal(t, "", result.Name) assert.Equal(t, 22, result.Age) }) t.Run("should update valid pointers on partial updates", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") type partialUser struct { ID uint `kisssql:"id"` Name string `kisssql:"name"` Age *int `kisssql:"age"` } u := partialUser{ Name: "Letícia", Age: nullable.Int(22), } _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 22)`) assert.Equal(t, nil, err) row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) assert.Equal(t, nil, row.Err()) err = row.Scan(&u.ID) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) // Should update all fields: err = c.Update(ctx, partialUser{ ID: u.ID, Name: "Thay", Age: nullable.Int(42), }) assert.Equal(t, nil, err) var result User err = getUserByID(c.db, c.dialect, &result, u.ID) assert.Equal(t, nil, err) assert.Equal(t, "Thay", result.Name) assert.Equal(t, 42, result.Age) }) t.Run("should report database errors correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "non_existing_table") err = c.Update(ctx, User{ ID: 1, Name: "Thayane", }) assert.NotEqual(t, nil, err) }) }) } } func TestQueryChunks(t *testing.T) { for _, driver := range []string{"sqlite3", "postgres"} { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{ Name: "User1", 
Address: Address{Country: "BR"}, }) var length int var u User err = c.QueryChunks(ctx, ChunkParser{ Query: `SELECT * FROM users WHERE name = ` + c.dialect.Placeholder(0), Params: []interface{}{"User1"}, ChunkSize: 100, ForEachChunk: func(users []User) error { length = len(users) if length > 0 { u = users[0] } return nil }, }) assert.Equal(t, nil, err) assert.Equal(t, 1, length) assert.NotEqual(t, uint(0), u.ID) assert.Equal(t, "User1", u.Name) assert.Equal(t, "BR", u.Address.Country) }) t.Run("should query one chunk correctly", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, ForEachChunk: func(buffer []User) error { users = append(users, buffer...) 
lengths = append(lengths, len(buffer)) return nil }, }) assert.Equal(t, nil, err) assert.Equal(t, 1, len(lengths)) assert.Equal(t, 2, lengths[0]) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.Equal(t, "US", users[0].Address.Country) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "User2", users[1].Name) assert.Equal(t, "BR", users[1].Address.Country) }) t.Run("should query chunks of 1 correctly", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 1, ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) 
return nil }, }) assert.Equal(t, nil, err) assert.Equal(t, 2, len(users)) assert.Equal(t, []int{1, 1}, lengths) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.Equal(t, "US", users[0].Address.Country) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "User2", users[1].Name) assert.Equal(t, "BR", users[1].Address.Country) }) t.Run("should load partially filled chunks correctly", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{Name: "User1"}) _ = c.Insert(ctx, &User{Name: "User2"}) _ = c.Insert(ctx, &User{Name: "User3"}) var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) 
return nil }, }) assert.Equal(t, nil, err) assert.Equal(t, 3, len(users)) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "User2", users[1].Name) assert.NotEqual(t, uint(0), users[2].ID) assert.Equal(t, "User3", users[2].Name) assert.Equal(t, []int{2, 1}, lengths) }) t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{Name: "User1"}) _ = c.Insert(ctx, &User{Name: "User2"}) _ = c.Insert(ctx, &User{Name: "User3"}) var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) 
return ErrAbortIteration }, }) assert.Equal(t, nil, err) assert.Equal(t, 2, len(users)) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "User2", users[1].Name) assert.Equal(t, []int{2}, lengths) }) t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{Name: "User1"}) _ = c.Insert(ctx, &User{Name: "User2"}) _ = c.Insert(ctx, &User{Name: "User3"}) returnVals := []error{nil, ErrAbortIteration} var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) 
return shiftErrSlice(&returnVals) }, }) assert.Equal(t, nil, err) assert.Equal(t, 3, len(users)) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "User2", users[1].Name) assert.NotEqual(t, uint(0), users[2].ID) assert.Equal(t, "User3", users[2].Name) assert.Equal(t, []int{2, 1}, lengths) }) t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{Name: "User1"}) _ = c.Insert(ctx, &User{Name: "User2"}) _ = c.Insert(ctx, &User{Name: "User3"}) var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) 
return errors.New("fake error msg") }, }) assert.NotEqual(t, nil, err) assert.Equal(t, 2, len(users)) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "User2", users[1].Name) assert.Equal(t, []int{2}, lengths) }) t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, &User{Name: "User1"}) _ = c.Insert(ctx, &User{Name: "User2"}) _ = c.Insert(ctx, &User{Name: "User3"}) returnVals := []error{nil, errors.New("fake error msg")} var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) 
return shiftErrSlice(&returnVals) }, }) assert.NotEqual(t, nil, err) assert.Equal(t, 3, len(users)) assert.NotEqual(t, uint(0), users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.NotEqual(t, uint(0), users[1].ID) assert.Equal(t, "User2", users[1].Name) assert.NotEqual(t, uint(0), users[2].ID) assert.Equal(t, "User3", users[2].Name) assert.Equal(t, []int{2, 1}, lengths) }) t.Run("should report error if the input function is invalid", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") funcs := []interface{}{ nil, "not a function", func() error { return nil }, func(extraInputValue []User, extra []User) error { return nil }, func(invalidArgType string) error { return nil }, func(missingReturnType []User) { return }, func(users []User) string { return "" }, func(extraReturnValue []User) ([]User, error) { return nil, nil }, func(notSliceOfStructs []string) error { return nil }, } for _, fn := range funcs { err := c.QueryChunks(ctx, ChunkParser{ Query: `SELECT * FROM users`, Params: []interface{}{}, ChunkSize: 2, ForEachChunk: fn, }) assert.NotEqual(t, nil, err) } }) t.Run("should report error if the query is not valid", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") err := c.QueryChunks(ctx, ChunkParser{ Query: `SELECT * FROM not a valid query`, Params: []interface{}{}, ChunkSize: 2, ForEachChunk: func(buffer []User) error { return nil }, }) assert.NotEqual(t, nil, err) }) }) } } func TestTransaction(t *testing.T) { for _, driver := range []string{"sqlite3", "postgres"} { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { err := createTable(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } db := connectDB(t, driver) defer db.Close() ctx := context.Background() c := newTestDB(db, driver, "users") _ = c.Insert(ctx, 
&User{Name: "User1"})
				_ = c.Insert(ctx, &User{Name: "User2"})

				var users []User
				err = c.Transaction(ctx, func(db SQLProvider) error {
					db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
					return nil
				})
				assert.Equal(t, nil, err)

				assert.Equal(t, 2, len(users))
				assert.Equal(t, "User1", users[0].Name)
				assert.Equal(t, "User2", users[1].Name)
			})

			t.Run("should rollback when there are errors", func(t *testing.T) {
				err := createTable(driver)
				if err != nil {
					t.Fatal("could not create test table!, reason:", err.Error())
				}

				db := connectDB(t, driver)
				defer db.Close()

				ctx := context.Background()
				c := newTestDB(db, driver, "users")

				u1 := User{Name: "User1", Age: 42}
				u2 := User{Name: "User2", Age: 42}
				_ = c.Insert(ctx, &u1)
				_ = c.Insert(ctx, &u2)

				err = c.Transaction(ctx, func(db SQLProvider) error {
					err = db.Insert(ctx, &User{Name: "User3"})
					assert.Equal(t, nil, err)
					err = db.Insert(ctx, &User{Name: "User4"})
					assert.Equal(t, nil, err)
					err = db.Exec(ctx, "UPDATE users SET age = 22")
					assert.Equal(t, nil, err)

					return errors.New("fake-error")
				})
				assert.NotEqual(t, nil, err)
				assert.Equal(t, "fake-error", err.Error())

				var users []User
				err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
				assert.Equal(t, nil, err)

				assert.Equal(t, []User{u1, u2}, users)
			})
		})
	}
}

// TestScanRows exercises scanRows directly against the sqlite3 test database.
// It covers four cases: a successful scan, a struct that omits some of the
// selected columns, closed rows, and record arguments that are not a pointer
// to a struct.
func TestScanRows(t *testing.T) {
	t.Run("should scan users correctly", func(t *testing.T) {
		err := createTable("sqlite3")
		if err != nil {
			t.Fatal("could not create test table!, reason:", err.Error())
		}

		ctx := context.TODO()
		db := connectDB(t, "sqlite3")
		defer db.Close()
		c := newTestDB(db, "sqlite3", "users")

		// Setup failures would otherwise surface as confusing scan assertions.
		err = c.Insert(ctx, &User{Name: "User1", Age: 22})
		assert.Equal(t, nil, err)
		err = c.Insert(ctx, &User{Name: "User2", Age: 14})
		assert.Equal(t, nil, err)
		err = c.Insert(ctx, &User{Name: "User3", Age: 43})
		assert.Equal(t, nil, err)

		rows, err := db.QueryContext(ctx, "select * from users where name='User2'")
		assert.Equal(t, nil, err)
		defer rows.Close()

		assert.Equal(t, true, rows.Next())

		var u User
		err = scanRows(rows, &u)
		assert.Equal(t, nil, err)

		assert.Equal(t, "User2", u.Name)
		assert.Equal(t, 14, u.Age)
	})

	t.Run("should ignore extra columns from query", func(t *testing.T) {
		err := createTable("sqlite3")
		if err != nil {
			t.Fatal("could not create test table!, reason:", err.Error())
		}

		ctx := context.TODO()
		db := connectDB(t, "sqlite3")
		defer db.Close()
		c := newTestDB(db, "sqlite3", "users")

		err = c.Insert(ctx, &User{Name: "User1", Age: 22})
		assert.Equal(t, nil, err)

		rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'")
		assert.Equal(t, nil, err)
		defer rows.Close()

		assert.Equal(t, true, rows.Next())

		var user struct {
			ID  int `kisssql:"id"`
			Age int `kisssql:"age"`

			// Omitted for testing purposes:
			// Name string `kisssql:"name"`
		}
		err = scanRows(rows, &user)
		assert.Equal(t, nil, err)

		assert.Equal(t, 22, user.Age)
	})

	t.Run("should report error for closed rows", func(t *testing.T) {
		err := createTable("sqlite3")
		if err != nil {
			t.Fatal("could not create test table!, reason:", err.Error())
		}

		ctx := context.TODO()
		db := connectDB(t, "sqlite3")
		defer db.Close()

		rows, err := db.QueryContext(ctx, "select * from users where name='User2'")
		assert.Equal(t, nil, err)

		var u User
		err = rows.Close()
		assert.Equal(t, nil, err)
		err = scanRows(rows, &u)
		assert.NotEqual(t, nil, err)
	})

	t.Run("should report if record is not a pointer", func(t *testing.T) {
		err := createTable("sqlite3")
		if err != nil {
			t.Fatal("could not create test table!, reason:", err.Error())
		}

		ctx := context.TODO()
		db := connectDB(t, "sqlite3")
		defer db.Close()

		rows, err := db.QueryContext(ctx, "select * from users where name='User2'")
		assert.Equal(t, nil, err)
		// Release the result set even though scanRows is expected to fail.
		defer rows.Close()

		var u User
		err = scanRows(rows, u)
		assert.NotEqual(t, nil, err)
	})

	t.Run("should report if record is not a pointer to struct", func(t *testing.T) {
		err := createTable("sqlite3")
		if err != nil {
			t.Fatal("could not create test table!, reason:", err.Error())
		}

		ctx := context.TODO()
		db := connectDB(t, "sqlite3")
		defer db.Close()

		rows, err := db.QueryContext(ctx, "select * from users where name='User2'")
		assert.Equal(t, nil, err)
		// Release the result set even though scanRows is expected to fail.
		defer rows.Close()

		var u map[string]interface{}
		err = scanRows(rows, &u)
		assert.NotEqual(t, nil, err)
	})
}

// connectionString maps each supported driver name to the DSN used by the
// tests. A driver missing from this map is treated as unsupported.
var connectionString = map[string]string{
	"postgres": "host=localhost port=5432 user=postgres password=postgres dbname=kisssql sslmode=disable",
	"sqlite3":  "/tmp/kisssql.db",
}

// createTable drops any existing "users" table and recreates it with the
// column types appropriate for driver (sqlite3 or postgres), so each test
// starts from an empty table. It returns an error for drivers missing from
// connectionString, or if opening, dropping or creating fails.
func createTable(driver string) error {
	connStr := connectionString[driver]
	if connStr == "" {
		return fmt.Errorf("unsupported driver: '%s'", driver)
	}

	db, err := sql.Open(driver, connStr)
	if err != nil {
		return err
	}
	defer db.Close()

	// IF EXISTS makes the drop succeed on a fresh database, so any error
	// reported here is a genuine failure rather than "table not found".
	if _, err := db.Exec(`DROP TABLE IF EXISTS users`); err != nil {
		return fmt.Errorf("failed to drop users table: %w", err)
	}

	switch driver {
	case "sqlite3":
		_, err = db.Exec(`CREATE TABLE users (
			id INTEGER PRIMARY KEY,
			age INTEGER,
			name TEXT,
			address BLOB
		)`)
	case "postgres":
		_, err = db.Exec(`CREATE TABLE users (
			id serial PRIMARY KEY,
			age INT,
			name VARCHAR(50),
			address jsonb
		)`)
	}
	if err != nil {
		return fmt.Errorf("failed to create new users table: %w", err)
	}

	return nil
}

// newTestDB builds a DB value for tableName, bypassing the public constructor.
// ids lists the ID column names and defaults to []string{"id"} when empty.
// The insert method is chosen per driver: insertWithLastInsertID for sqlite3
// and insertWithReturning for postgres. Any other driver gets the map's zero
// value.
func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB {
	if len(ids) == 0 {
		ids = []string{"id"}
	}

	return DB{
		driver:    driver,
		dialect:   getDriverDialect(driver),
		db:        db,
		tableName: tableName,
		idCols:    ids,
		insertMethod: map[string]insertMethod{
			"sqlite3":  insertWithLastInsertID,
			"postgres": insertWithReturning,
		}[driver],
	}
}

// connectDB opens a *sql.DB for driver using connectionString. It fails the
// test immediately if the driver is unknown or sql.Open returns an error.
// The caller is responsible for closing the returned handle.
func connectDB(t *testing.T, driver string) *sql.DB {
	t.Helper()

	connStr := connectionString[driver]
	if connStr == "" {
		t.Fatalf("unsupported driver: '%s'", driver)
	}

	db, err := sql.Open(driver, connStr)
	if err != nil {
		t.Fatal(err.Error())
	}
	return db
}

// shiftErrSlice removes and returns the first error in *errs. It panics if
// *errs is empty.
func shiftErrSlice(errs *[]error) error {
	err := (*errs)[0]
	*errs = (*errs)[1:]
	return err
}

// getUsersByID reads the id, name and age columns of every user whose id is
// in ids and stores them in *resultsPtr. The address column is not read.
// dbi must hold a *sql.DB; the type assertion panics otherwise.
// *resultsPtr is only written when the query and scan both succeed.
// An empty ids list yields an empty (non-nil) slice without querying, because
// "IN ()" is not valid SQL on every driver.
func getUsersByID(dbi sqlProvider, dialect dialect, resultsPtr *[]User, ids ...uint) error {
	db := dbi.(*sql.DB)

	if len(ids) == 0 {
		*resultsPtr = []User{}
		return nil
	}

	placeholders := make([]string, len(ids))
	params := make([]interface{}, len(ids))
	for i, id := range ids {
		params[i] = id
		placeholders[i] = dialect.Placeholder(i)
	}

	rows, err := db.Query(
		fmt.Sprintf(
			"SELECT id, name, age FROM users WHERE id IN (%s)",
			strings.Join(placeholders, ", "),
		),
		params...,
	)
	if err != nil {
		return err
	}
	defer rows.Close()

	results := make([]User, 0, len(ids))
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.ID, &u.Name, &u.Age); err != nil {
			return err
		}
		results = append(results, u)
	}
	if err := rows.Err(); err != nil {
		return err
	}
	// Close explicitly so a close error is reported. The deferred Close above
	// then only covers the early-return paths.
	if err := rows.Close(); err != nil {
		return err
	}

	*resultsPtr = results
	return nil
}

// getUserByID loads the user with the given id into result, including the
// JSON-encoded address column. dbi must hold a *sql.DB; the type assertion
// panics otherwise.
func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error {
	db := dbi.(*sql.DB)

	row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id)
	return scanUserWithAddress(row, result)
}

// getUserByName loads the user with the given name into result, including the
// JSON-encoded address column. dbi must hold a *sql.DB; the type assertion
// panics otherwise.
func getUserByName(dbi sqlProvider, dialect dialect, result *User, name string) error {
	db := dbi.(*sql.DB)

	row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name)
	return scanUserWithAddress(row, result)
}

// scanUserWithAddress scans an (id, name, age, address) row into result and
// JSON-decodes the address bytes into result.Address. If the scanned address
// is nil (e.g. a NULL column), result.Address is left untouched.
// Any error from the underlying query is returned by row.Scan itself.
func scanUserWithAddress(row *sql.Row, result *User) error {
	var rawAddr []byte
	if err := row.Scan(&result.ID, &result.Name, &result.Age, &rawAddr); err != nil {
		return err
	}
	if rawAddr == nil {
		return nil
	}
	return json.Unmarshal(rawAddr, &result.Address)
}