ksql/test_adapters.go

3766 lines
108 KiB
Go

package ksql
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
"strings"
"testing"
"time"
"github.com/vingarcia/ksql/internal/modifiers"
"github.com/vingarcia/ksql/internal/structs"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/ksqlmodifiers"
"github.com/vingarcia/ksql/nullable"
"github.com/vingarcia/ksql/sqldialect"
)
var usersTable = NewTable("users")
type user struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
Age int `ksql:"age"`
Address address `ksql:"address,json"`
// This attr has no ksql tag, thus, it should be ignored:
AttrThatShouldBeIgnored string
}
type address struct {
Street string `json:"street"`
Number string `json:"number"`
City string `json:"city"`
State string `json:"state"`
Country string `json:"country"`
}
var postsTable = NewTable("posts")
type post struct {
ID int `ksql:"id"`
UserID uint `ksql:"user_id"`
Title string `ksql:"title"`
}
var userPermissionsTable = NewTable("user_permissions", "user_id", "perm_id")
type userPermission struct {
ID int `ksql:"id"`
UserID int `ksql:"user_id"`
PermID int `ksql:"perm_id"`
Type string `ksql:"type"`
}
// RunTestsForAdapter will run all necessary tests for making sure
// a given adapter is working as expected.
//
// Optionally it is also possible to run each of these tests
// separatedly, which might be useful during the development
// of a new adapter.
func RunTestsForAdapter(
t *testing.T,
adapterName string,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
t.Run(adapterName, func(t *testing.T) {
t.Run(dialect.DriverName(), func(t *testing.T) {
QueryTest(t, dialect, connStr, newDBAdapter)
QueryOneTest(t, dialect, connStr, newDBAdapter)
InsertTest(t, dialect, connStr, newDBAdapter)
DeleteTest(t, dialect, connStr, newDBAdapter)
PatchTest(t, dialect, connStr, newDBAdapter)
QueryChunksTest(t, dialect, connStr, newDBAdapter)
TransactionTest(t, dialect, connStr, newDBAdapter)
ModifiersTest(t, dialect, connStr, newDBAdapter)
ScanRowsTest(t, dialect, connStr, newDBAdapter)
})
})
}
// QueryTest runs all tests for making sure the Query function is
// working for a given adapter and dialect.
func QueryTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("Query", func(t *testing.T) {
variations := []struct {
desc string
queryPrefix string
}{
{
desc: "with select *",
queryPrefix: "SELECT * ",
},
{
desc: "building the SELECT part of the query internally",
queryPrefix: "",
},
}
for _, variation := range variations {
t.Run(variation.desc, func(t *testing.T) {
t.Run("using slice of structs", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should return 0 results correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var users []user
err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 0)
users = []user{}
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 0)
})
t.Run("should return a user correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var users []user
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 1)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "Bia")
tt.AssertEqual(t, users[0].Address.Country, "BR")
})
t.Run("should return multiple users correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var users []user
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 2)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "João Garcia")
tt.AssertEqual(t, users[0].Address.Country, "US")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "Bia Garcia")
tt.AssertEqual(t, users[1].Address.Country, "BR")
})
t.Run("should query joined tables correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
// This test only makes sense with no query prefix
if variation.queryPrefix != "" {
return
}
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
var joao user
getUserByName(db, dialect, &joao, "João Ribeiro")
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
var bia user
getUserByName(db, dialect, &bia, "Bia Ribeiro")
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`))
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`))
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`))
tt.AssertNoErr(t, err)
ctx := context.Background()
c := newTestDB(db, dialect)
var rows []struct {
User user `tablename:"u"`
Post post `tablename:"p"`
// This one has no ksql or tablename tag,
// so it should just be ignored to avoid strange
// unexpected errors:
ExtraStructThatShouldBeIgnored user
}
err = c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(rows), 3)
tt.AssertEqual(t, rows[0].User.ID, joao.ID)
tt.AssertEqual(t, rows[0].User.Name, "João Ribeiro")
tt.AssertEqual(t, rows[0].Post.Title, "João Post1")
tt.AssertNotEqual(t, rows[0].Post.ID, 0)
tt.AssertEqual(t, rows[1].User.ID, bia.ID)
tt.AssertEqual(t, rows[1].User.Name, "Bia Ribeiro")
tt.AssertEqual(t, rows[1].Post.Title, "Bia Post1")
tt.AssertNotEqual(t, rows[1].Post.ID, 0)
tt.AssertEqual(t, rows[2].User.ID, bia.ID)
tt.AssertEqual(t, rows[2].User.Name, "Bia Ribeiro")
tt.AssertEqual(t, rows[2].Post.Title, "Bia Post2")
tt.AssertNotEqual(t, rows[2].Post.ID, 0)
})
})
t.Run("using slice of pointers to structs", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should return 0 results correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var users []*user
err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 0)
users = []*user{}
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 0)
})
t.Run("should return a user correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var users []*user
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 1)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "Bia")
tt.AssertEqual(t, users[0].Address.Country, "BR")
})
t.Run("should return multiple users correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var users []*user
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 2)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "João Garcia")
tt.AssertEqual(t, users[0].Address.Country, "US")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "Bia Garcia")
tt.AssertEqual(t, users[1].Address.Country, "BR")
})
t.Run("should query joined tables correctly", func(t *testing.T) {
// This test only makes sense with no query prefix
if variation.queryPrefix != "" {
return
}
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
var joao user
getUserByName(db, dialect, &joao, "João Ribeiro")
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
var bia user
getUserByName(db, dialect, &bia, "Bia Ribeiro")
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`))
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`))
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`))
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var rows []*struct {
User user `tablename:"u"`
Post post `tablename:"p"`
}
err = c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(rows), 3)
tt.AssertEqual(t, rows[0].User.ID, joao.ID)
tt.AssertEqual(t, rows[0].User.Name, "João Ribeiro")
tt.AssertEqual(t, rows[0].Post.Title, "João Post1")
tt.AssertEqual(t, rows[1].User.ID, bia.ID)
tt.AssertEqual(t, rows[1].User.Name, "Bia Ribeiro")
tt.AssertEqual(t, rows[1].Post.Title, "Bia Post1")
tt.AssertEqual(t, rows[2].User.ID, bia.ID)
tt.AssertEqual(t, rows[2].User.Name, "Bia Ribeiro")
tt.AssertEqual(t, rows[2].Post.Title, "Bia Post2")
})
})
})
}
t.Run("testing error cases", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should report error if input is not a pointer to a slice of structs", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Andréa Sá', 0)`)
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Caio Sá', 0)`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
err = c.Query(ctx, &user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
tt.AssertErrContains(t, err, "expected", "to be a slice", "user")
err = c.Query(ctx, []*user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
tt.AssertErrContains(t, err, "expected", "slice of structs", "user")
var i int
err = c.Query(ctx, &i, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
tt.AssertErrContains(t, err, "expected", "to be a slice", "int")
err = c.Query(ctx, &[]int{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
tt.AssertErrContains(t, err, "expected", "slice of structs", "[]int")
})
t.Run("should report error if the query is not valid", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var users []user
err := c.Query(ctx, &users, `SELECT * FROM not a valid query`)
tt.AssertErrContains(t, err, "error running query")
})
t.Run("should report error if the TagInfoCache returns an error", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
// Provoque an error by sending an invalid struct:
var users []struct {
ID int `ksql:"id"`
// Private names cannot have ksql tags:
badPrivateField string `ksql:"name"`
}
err := c.Query(ctx, &users, `SELECT * FROM users`)
tt.AssertErrContains(t, err, "badPrivateField")
})
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var rows []struct {
User user `tablename:"users"`
Post post `tablename:"posts"`
}
err := c.Query(ctx, &rows, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`)
tt.AssertErrContains(t, err, "nested struct", "feature")
})
t.Run("should report error for nested structs with invalid types", func(t *testing.T) {
t.Run("int", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var rows []struct {
Foo int `tablename:"foo"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
tt.AssertErrContains(t, err, "foo", "int")
})
t.Run("*struct", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var rows []struct {
Foo *user `tablename:"foo"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
tt.AssertErrContains(t, err, "foo", "*ksql.user")
})
})
t.Run("should report error if nested struct is invalid", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var rows []struct {
User user `tablename:"users"`
Post struct {
Attr1 int `ksql:"invalid_repeated_name"`
Attr2 int `ksql:"invalid_repeated_name"`
} `tablename:"posts"`
}
err := c.Query(ctx, &rows, `FROM users u JOIN posts p ON u.id = p.user_id`)
tt.AssertErrContains(t, err, "same ksql tag name", "invalid_repeated_name")
})
t.Run("should report error if DBAdapter.Scan() returns an error", func(t *testing.T) {
db := mockDBAdapter{
QueryContextFn: func(ctx context.Context, query string, args ...interface{}) (Rows, error) {
return mockRows{
ColumnsFn: func() ([]string, error) {
return []string{"id", "name", "age", "address"}, nil
},
NextFn: func() bool { return true },
ScanFn: func(values ...interface{}) error {
return fmt.Errorf("fakeScanErr")
},
}, nil
},
}
c := newTestDB(db, dialect)
var users []user
err := c.Query(ctx, &users, `SELECT * FROM users`)
tt.AssertErrContains(t, err, "KSQL", "scan error", "fakeScanErr")
})
t.Run("should report error if DBAdapter.Err() returns an error", func(t *testing.T) {
db := mockDBAdapter{
QueryContextFn: func(ctx context.Context, query string, args ...interface{}) (Rows, error) {
return mockRows{
ColumnsFn: func() ([]string, error) {
return []string{"id", "name", "age", "address"}, nil
},
NextFn: func() bool { return false },
ScanFn: func(values ...interface{}) error {
return nil
},
ErrFn: func() error {
return fmt.Errorf("fakeErrMsg")
},
}, nil
},
}
c := newTestDB(db, dialect)
var users []user
err := c.Query(ctx, &users, `SELECT * FROM users`)
tt.AssertErrContains(t, err, "KSQL", "fakeErrMsg")
})
t.Run("should report error if DBAdapter.Close() returns an error", func(t *testing.T) {
db := mockDBAdapter{
QueryContextFn: func(ctx context.Context, query string, args ...interface{}) (Rows, error) {
return mockRows{
ColumnsFn: func() ([]string, error) {
return []string{"id", "name", "age", "address"}, nil
},
NextFn: func() bool { return false },
ScanFn: func(values ...interface{}) error {
return nil
},
CloseFn: func() error {
return fmt.Errorf("fakeCloseErr")
},
}, nil
},
}
c := newTestDB(db, dialect)
var users []user
err := c.Query(ctx, &users, `SELECT * FROM users`)
tt.AssertErrContains(t, err, "KSQL", "fakeCloseErr")
})
t.Run("should report error context.Canceled if the context is canceled", func(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
cancel()
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var users []user
err := c.Query(ctx, &users, `SELECT * FROM users`)
tt.AssertErrContains(t, err, "context", "canceled")
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
})
})
})
}
// QueryOneTest runs all tests for making sure the QueryOne function is
// working for a given adapter and dialect.
func QueryOneTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("QueryOne", func(t *testing.T) {
variations := []struct {
desc string
queryPrefix string
}{
{
desc: "with select *",
queryPrefix: "SELECT * ",
},
{
desc: "building the SELECT part of the query internally",
queryPrefix: "",
},
}
for _, variation := range variations {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run(variation.desc, func(t *testing.T) {
t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u := user{}
err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`)
tt.AssertEqual(t, err, ErrRecordNotFound)
})
t.Run("should return a user correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
u := user{}
err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, uint(0))
tt.AssertEqual(t, u.Name, "Bia")
tt.AssertEqual(t, u.Address, address{
Country: "BR",
})
})
t.Run("should return only the first result on multiples matches", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var u user
err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Name, "Andréa Sá")
tt.AssertEqual(t, u.Age, 0)
tt.AssertEqual(t, u.Address, address{
Country: "US",
})
})
t.Run("should query joined tables correctly", func(t *testing.T) {
// This test only makes sense with no query prefix
if variation.queryPrefix != "" {
return
}
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
var joao user
getUserByName(db, dialect, &joao, "João Ribeiro")
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`))
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var row struct {
User user `tablename:"u"`
Post post `tablename:"p"`
}
err = c.QueryOne(ctx, &row, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, row.User.ID, joao.ID)
tt.AssertEqual(t, row.User.Name, "João Ribeiro")
tt.AssertEqual(t, row.Post.Title, "João Post1")
})
t.Run("should handle column tags as case-insensitive as SQL does", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Count Olivia', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var row struct {
Count int `ksql:"myCount"`
}
err = c.QueryOne(ctx, &row, `SELECT count(*) as myCount FROM users WHERE name='Count Olivia'`)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, row.Count, 1)
})
})
}
t.Run("should report error if input is not a pointer to struct", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
err = c.QueryOne(ctx, &[]user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
tt.AssertErrContains(t, err, "pointer to struct")
err = c.QueryOne(ctx, user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
tt.AssertErrContains(t, err, "pointer to struct")
})
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var u *user
err := c.QueryOne(ctx, u, `SELECT * FROM users`)
tt.AssertErrContains(t, err, "expected a valid pointer", "received a nil pointer")
})
t.Run("should report error if the query is not valid", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var u user
err := c.QueryOne(ctx, &u, `SELECT * FROM not a valid query`)
tt.AssertErrContains(t, err, "error running query")
})
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var row struct {
User user `tablename:"users"`
Post post `tablename:"posts"`
}
err := c.QueryOne(ctx, &row, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id LIMIT 1`)
tt.AssertErrContains(t, err, "nested struct", "feature")
})
t.Run("should report error if a private field has a ksql tag", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Olivia', 0, '{"country":"US"}')`)
tt.AssertNoErr(t, err)
c := newTestDB(db, dialect)
var row struct {
count int `ksql:"my_count"`
}
err = c.QueryOne(ctx, &row, `SELECT count(*) as my_count FROM users`)
tt.AssertErrContains(t, err, "unexported", "my_count")
})
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx, cancel := context.WithCancel(ctx)
cancel()
c := newTestDB(db, dialect)
var u user
err := c.QueryOne(ctx, &u, `SELECT * FROM users LIMIT 1`)
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
})
})
}
// InsertTest runs all tests for making sure the Insert function is
// working for a given adapter and dialect.
func InsertTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("Insert", func(t *testing.T) {
t.Run("success cases", func(t *testing.T) {
t.Run("single primary key tables", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should insert one user correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u := user{
Name: "Fernanda",
Address: address{
Country: "Brazil",
},
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
result := user{}
err = getUserByID(c.db, c.dialect, &result, u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.Name, u.Name)
tt.AssertEqual(t, result.Address, u.Address)
})
t.Run("should insert ignoring the ID with multiple ids", func(t *testing.T) {
if dialect.InsertMethod() != sqldialect.InsertWithLastInsertID {
return
}
// Using columns "id" and "name" as IDs:
table := NewTable("users", "id", "name")
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u := user{
Name: "No ID returned",
Age: 3434, // Random number to avoid false positives on this test
Address: address{
Country: "Brazil 3434",
},
}
err = c.Insert(ctx, table, &u)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.ID, uint(0))
result := user{}
err = getUserByName(c.db, dialect, &result, "No ID returned")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.Age, u.Age)
tt.AssertEqual(t, result.Address, u.Address)
})
t.Run("should work with anonymous structs", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Insert(ctx, usersTable, &struct {
ID int `ksql:"id"`
Name string `ksql:"name"`
Address map[string]interface{} `ksql:"address,json"`
}{Name: "fake-name", Address: map[string]interface{}{"city": "bar"}})
tt.AssertNoErr(t, err)
})
t.Run("should work with preset IDs", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
usersByName := NewTable("users", "name")
err = c.Insert(ctx, usersByName, &struct {
Name string `ksql:"name"`
Age int `ksql:"age"`
}{Name: "Preset Name", Age: 5455})
tt.AssertNoErr(t, err)
var inserted user
err := getUserByName(db, dialect, &inserted, "Preset Name")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, inserted.Age, 5455)
})
t.Run("should work and retrieve the ID for structs with no attributes", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name *string `ksql:"name"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
})
})
t.Run("composite key tables", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should insert in composite key tables correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
table := NewTable("user_permissions", "id", "user_id", "perm_id")
err = c.Insert(ctx, table, &userPermission{
UserID: 1,
PermID: 42,
})
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, dialect, 1)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 1)
tt.AssertEqual(t, userPerms[0].PermID, 42)
})
t.Run("should accept partially provided values for composite key tables", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
// Table defined with 3 values, but we'll provide only 2,
// the third will be generated for the purposes of this test:
table := NewTable("user_permissions", "id", "user_id", "perm_id")
permission := userPermission{
UserID: 2,
PermID: 42,
}
err = c.Insert(ctx, table, &permission)
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, dialect, 2)
tt.AssertNoErr(t, err)
// Should retrieve the generated ID from the database,
// only if the database supports returning multiple values:
switch c.dialect.InsertMethod() {
case sqldialect.InsertWithNoIDRetrieval, sqldialect.InsertWithLastInsertID:
tt.AssertEqual(t, permission.ID, 0)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 2)
tt.AssertEqual(t, userPerms[0].PermID, 42)
case sqldialect.InsertWithReturning, sqldialect.InsertWithOutput:
tt.AssertNotEqual(t, permission.ID, 0)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].ID, permission.ID)
tt.AssertEqual(t, userPerms[0].UserID, 2)
tt.AssertEqual(t, userPerms[0].PermID, 42)
}
})
t.Run("when inserting a struct with no values but composite keys should still retrieve the IDs", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
// Table defined with 3 values, but we'll provide only 2,
// the third will be generated for the purposes of this test:
table := NewTable("user_permissions", "id", "user_id", "perm_id")
type taggedPerm struct {
ID uint `ksql:"id"`
UserID int `ksql:"user_id"`
PermID int `ksql:"perm_id"`
Type string `ksql:"type,skipInserts"`
}
permission := taggedPerm{
UserID: 3,
PermID: 43,
}
err := c.Insert(ctx, table, &permission)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, permission.ID, 0)
var untaggedPerm struct {
ID uint `ksql:"id"`
UserID int `ksql:"user_id"`
PermID int `ksql:"perm_id"`
Type *string `ksql:"type"`
}
err = c.QueryOne(ctx, &untaggedPerm, `FROM user_permissions WHERE user_id = 3 AND perm_id = 43`)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedPerm.Type, (*string)(nil))
// Should retrieve the generated ID from the database,
// only if the database supports returning multiple values:
switch c.dialect.InsertMethod() {
case sqldialect.InsertWithNoIDRetrieval, sqldialect.InsertWithLastInsertID:
tt.AssertEqual(t, permission.ID, uint(0))
tt.AssertEqual(t, untaggedPerm.UserID, 3)
tt.AssertEqual(t, untaggedPerm.PermID, 43)
case sqldialect.InsertWithReturning, sqldialect.InsertWithOutput:
tt.AssertEqual(t, untaggedPerm.ID, permission.ID)
tt.AssertEqual(t, untaggedPerm.UserID, 3)
tt.AssertEqual(t, untaggedPerm.PermID, 43)
}
})
})
})
t.Run("testing error cases", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should report error for invalid input types", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Insert(ctx, usersTable, "foo")
tt.AssertNotEqual(t, err, nil)
err = c.Insert(ctx, usersTable, nullable.String("foo"))
tt.AssertNotEqual(t, err, nil)
err = c.Insert(ctx, usersTable, map[string]interface{}{
"name": "foo",
"age": 12,
})
tt.AssertNotEqual(t, err, nil)
cantInsertSlice := []interface{}{
&user{Name: "foo", Age: 22},
&user{Name: "bar", Age: 32},
}
err = c.Insert(ctx, usersTable, cantInsertSlice)
tt.AssertNotEqual(t, err, nil)
// We might want to support this in the future, but not for now:
err = c.Insert(ctx, usersTable, user{Name: "not a ptr to user", Age: 42})
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if for some reason the InsertMethod is invalid", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
// This is an invalid value:
c.dialect = brokenDialect{}
err = c.Insert(ctx, usersTable, &user{Name: "foo"})
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var u *user
err := c.Insert(ctx, usersTable, u)
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if table contains an empty ID name", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Insert(ctx, NewTable("users", ""), &user{Name: "fake-name"})
tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string")
})
t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Insert(ctx, NewTable("", "id"), &user{Name: "fake-name"})
tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string")
})
t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Insert(ctx, usersTable, &struct {
ID string `ksql:"id"`
NonExistingColumn int `ksql:"non_existing"`
Name string `ksql:"name"`
}{NonExistingColumn: 42, Name: "fake-name"})
tt.AssertErrContains(t, err, "column", "non_existing")
})
t.Run("should not panic if the ID column doesn't exist in the database", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
brokenTable := NewTable("users", "non_existing_id")
_ = c.Insert(ctx, brokenTable, &struct {
ID string `ksql:"non_existing_id"`
Age int `ksql:"age"`
Name string `ksql:"name"`
}{Age: 42, Name: "fake-name"})
})
t.Run("should not panic if the ID column is missing in the struct", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Insert(ctx, usersTable, &struct {
Age int `ksql:"age"`
Name string `ksql:"name"`
}{Age: 42, Name: "Inserted With no ID"})
tt.AssertNoErr(t, err)
var u user
err = getUserByName(db, dialect, &u, "Inserted With no ID")
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, uint(0))
tt.AssertEqual(t, u.Age, 42)
})
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx, cancel := context.WithCancel(ctx)
cancel()
c := newTestDB(db, dialect)
err = c.Insert(ctx, usersTable, &user{Age: 42, Name: "FakeUserName"})
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
})
})
})
}
type brokenDialect struct{}
func (brokenDialect) InsertMethod() sqldialect.InsertMethod {
return sqldialect.InsertMethod(42)
}
func (brokenDialect) Escape(str string) string {
return str
}
func (brokenDialect) Placeholder(idx int) string {
return "?"
}
func (brokenDialect) DriverName() string {
return "fake-driver-name"
}
// DeleteTest runs all tests for making sure the Delete function is
// working for a given adapter and dialect.
func DeleteTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("Delete", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should delete from tables with a single primary key correctly", func(t *testing.T) {
tests := []struct {
desc string
deletionKeyForUser func(u user) interface{}
}{
{
desc: "passing only the ID as key",
deletionKeyForUser: func(u user) interface{} {
return u.ID
},
},
{
desc: "passing only the entire user",
deletionKeyForUser: func(u user) interface{} {
return u
},
},
{
desc: "passing the address of the user",
deletionKeyForUser: func(u user) interface{} {
return &u
},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u1 := user{
Name: "Fernanda",
}
err := c.Insert(ctx, usersTable, &u1)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u1.ID, uint(0))
result := user{}
err = getUserByID(c.db, c.dialect, &result, u1.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.ID, u1.ID)
u2 := user{
Name: "Won't be deleted",
}
err = c.Insert(ctx, usersTable, &u2)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u2.ID, uint(0))
result = user{}
err = getUserByID(c.db, c.dialect, &result, u2.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.ID, u2.ID)
err = c.Delete(ctx, usersTable, test.deletionKeyForUser(u1))
tt.AssertNoErr(t, err)
result = user{}
err = getUserByID(c.db, c.dialect, &result, u1.ID)
tt.AssertEqual(t, err, sql.ErrNoRows)
result = user{}
err = getUserByID(c.db, c.dialect, &result, u2.ID)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, result.ID, uint(0))
tt.AssertEqual(t, result.Name, "Won't be deleted")
})
}
})
t.Run("should delete from tables with composite primary keys correctly", func(t *testing.T) {
t.Run("using structs", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
// This permission should not be deleted, we'll use the id to check it:
p0 := userPermission{
UserID: 1,
PermID: 44,
}
err = createUserPermission(db, c.dialect, p0)
tt.AssertNoErr(t, err)
p0, err = getUserPermissionBySecondaryKeys(db, c.dialect, 1, 44)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, p0.ID, 0)
p1 := userPermission{
UserID: 1,
PermID: 42,
}
err = createUserPermission(db, c.dialect, p1)
tt.AssertNoErr(t, err)
err = c.Delete(ctx, userPermissionsTable, p1)
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, dialect, 1)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 1)
tt.AssertEqual(t, userPerms[0].PermID, 44)
})
t.Run("using maps", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
// This permission should not be deleted, we'll use the id to check it:
p0 := userPermission{
UserID: 2,
PermID: 44,
}
err = createUserPermission(db, c.dialect, p0)
tt.AssertNoErr(t, err)
p0, err = getUserPermissionBySecondaryKeys(db, c.dialect, 1, 44)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, p0.ID, 0)
p1 := userPermission{
UserID: 2,
PermID: 42,
}
err = createUserPermission(db, c.dialect, p1)
tt.AssertNoErr(t, err)
err = c.Delete(ctx, userPermissionsTable, map[string]interface{}{
"user_id": 2,
"perm_id": 42,
})
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, dialect, 2)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 2)
tt.AssertEqual(t, userPerms[0].PermID, 44)
})
})
t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Delete(ctx, usersTable, 4200)
tt.AssertEqual(t, err, ErrRecordNotFound)
})
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var u *user
err := c.Delete(ctx, usersTable, u)
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if one of the ids is missing from the input", func(t *testing.T) {
t.Run("single id", func(t *testing.T) {
t.Run("struct with missing attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, NewTable("users", "id"), &struct {
// Missing ID
Name string `ksql:"name"`
}{Name: "fake-name"})
tt.AssertErrContains(t, err, "missing required", "id")
})
t.Run("struct with NULL attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, NewTable("users", "id"), &struct {
// Null ID
ID *int `ksql:"id"`
Name string `ksql:"name"`
}{Name: "fake-name"})
tt.AssertErrContains(t, err, "missing required", "id")
})
t.Run("struct with zero attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, NewTable("users", "id"), &struct {
// Uninitialized ID
ID int `ksql:"id"`
Name string `ksql:"name"`
}{Name: "fake-name"})
tt.AssertErrContains(t, err, "invalid value", "0", "id")
})
})
t.Run("multiple ids", func(t *testing.T) {
t.Run("struct with missing attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, userPermissionsTable, &struct {
// Missing PermID
UserID int `ksql:"user_id"`
Name string `ksql:"name"`
}{
UserID: 1,
Name: "fake-name",
})
tt.AssertErrContains(t, err, "missing required", "perm_id")
})
t.Run("map with missing attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, userPermissionsTable, map[string]interface{}{
// Missing PermID
"user_id": 1,
"name": "fake-name",
})
tt.AssertErrContains(t, err, "missing required", "perm_id")
})
t.Run("struct with NULL attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, userPermissionsTable, &struct {
UserID int `ksql:"user_id"`
PermID *int `ksql:"perm_id"`
Name string `ksql:"name"`
}{
// Null Perm ID
UserID: 1,
PermID: nil,
Name: "fake-name",
})
tt.AssertErrContains(t, err, "missing required", "perm_id")
})
t.Run("map with NULL attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, userPermissionsTable, map[string]interface{}{
// Null Perm ID
"user_id": 1,
"perm_id": nil,
"name": "fake-name",
})
tt.AssertErrContains(t, err, "invalid value", "nil", "perm_id")
})
t.Run("struct with zero attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, userPermissionsTable, &struct {
UserID int `ksql:"user_id"`
PermID int `ksql:"perm_id"`
Name string `ksql:"name"`
}{
// Zero Perm ID
UserID: 1,
PermID: 0,
Name: "fake-name",
})
tt.AssertErrContains(t, err, "invalid value", "0", "perm_id")
})
t.Run("map with zero attr", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, userPermissionsTable, map[string]interface{}{
// Zero Perm ID
"user_id": 1,
"perm_id": 0,
"name": "fake-name",
})
tt.AssertErrContains(t, err, "invalid value", "0", "perm_id")
})
})
})
t.Run("should report error if table contains an empty ID name", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, NewTable("users", ""), &user{ID: 42, Name: "fake-name"})
tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string")
})
t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Delete(ctx, NewTable("", "id"), &user{Name: "fake-name"})
tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string")
})
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx, cancel := context.WithCancel(ctx)
cancel()
c := newTestDB(db, dialect)
err := c.Delete(ctx, usersTable, &user{ID: 42, Name: "fake-name"})
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
})
})
}
// PatchTest runs all tests for making sure the Patch function is
// working for a given adapter and dialect.
func PatchTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("Patch", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should update one user{} correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u := user{
Name: "Letícia",
}
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`)
tt.AssertNoErr(t, err)
err = getUserByName(db, dialect, &u, "Letícia")
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, uint(0))
err = c.Update(ctx, usersTable, user{
ID: u.ID,
Name: "Thayane",
})
tt.AssertNoErr(t, err)
var result user
err = getUserByID(c.db, c.dialect, &result, u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.Name, "Thayane")
})
t.Run("should update one &user{} correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u := user{
Name: "Letícia",
}
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`)
tt.AssertNoErr(t, err)
err = getUserByName(db, dialect, &u, "Letícia")
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, uint(0))
err = c.Update(ctx, usersTable, &user{
ID: u.ID,
Name: "Thayane",
})
tt.AssertNoErr(t, err)
var result user
err = getUserByID(c.db, c.dialect, &result, u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.Name, "Thayane")
})
t.Run("should update tables with composite keys correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = createUserPermission(db, c.dialect, userPermission{
UserID: 42,
PermID: 43,
Type: "existingFakeType",
})
tt.AssertNoErr(t, err)
existingPerm, err := getUserPermissionBySecondaryKeys(db, c.dialect, 42, 43)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, existingPerm.ID, 0)
tt.AssertEqual(t, existingPerm.Type, "existingFakeType")
err = c.Update(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &userPermission{
ID: existingPerm.ID,
UserID: 42,
PermID: 43,
Type: "newFakeType",
})
tt.AssertNoErr(t, err)
newPerm, err := getUserPermissionBySecondaryKeys(db, c.dialect, 42, 43)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, newPerm, userPermission{
ID: existingPerm.ID,
UserID: 42,
PermID: 43,
Type: "newFakeType",
})
})
t.Run("should ignore null pointers on partial updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type partialUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
Age *int `ksql:"age"`
}
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`)
tt.AssertNoErr(t, err)
var u user
err = getUserByName(db, dialect, &u, "Letícia")
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, uint(0))
err = c.Update(ctx, usersTable, partialUser{
ID: u.ID,
// Should be updated because it is not null, just empty:
Name: "",
// Should not be updated because it is null:
Age: nil,
})
tt.AssertNoErr(t, err)
var result user
err = getUserByID(c.db, c.dialect, &result, u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.Name, "")
tt.AssertEqual(t, result.Age, 22)
})
t.Run("should update valid pointers on partial updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type partialUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
Age *int `ksql:"age"`
}
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`)
tt.AssertNoErr(t, err)
var u user
err = getUserByName(db, dialect, &u, "Letícia")
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, uint(0))
// Should update all fields:
err = c.Update(ctx, usersTable, partialUser{
ID: u.ID,
Name: "Thay",
Age: nullable.Int(42),
})
tt.AssertNoErr(t, err)
var result user
err = getUserByID(c.db, c.dialect, &result, u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.Name, "Thay")
tt.AssertEqual(t, result.Age, 42)
})
t.Run("should return ErrRecordNotFound when asked to update an inexistent user", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Update(ctx, usersTable, user{
ID: 4200,
Name: "Thayane",
})
tt.AssertEqual(t, err, ErrRecordNotFound)
})
t.Run("should report database errors correctly", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Update(ctx, NewTable("non_existing_table"), user{
ID: 1,
Name: "Thayane",
})
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
var u *user
err := c.Update(ctx, usersTable, u)
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if the struct has no fields to update", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err = c.Update(ctx, usersTable, struct {
ID uint `ksql:"id"` // ID fields are not updated
Name string `ksql:"name,skipUpdates"` // the skipUpdate modifier should rule this one out
Age *int `ksql:"age"` // Age is a nil pointer so it would not be updated
}{
ID: 1,
Name: "some name",
})
tt.AssertErrContains(t, err, "struct", "no values to update")
})
t.Run("should report error if the id is missing", func(t *testing.T) {
t.Run("with a single primary key", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Update(ctx, usersTable, &user{
// Missing ID
Name: "Jane",
})
tt.AssertErrContains(t, err, "invalid value", "0", "'id'")
})
t.Run("with composite keys", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Update(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &userPermission{
ID: 1,
// Missing UserID
PermID: 42,
})
tt.AssertErrContains(t, err, "invalid value", "0", "'user_id'")
})
})
t.Run("should report error if the struct has no id", func(t *testing.T) {
t.Run("with a single primary key", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Update(ctx, usersTable, &struct {
// Missing ID
Name string `ksql:"name"`
}{
Name: "Jane",
})
tt.AssertErrContains(t, err, "missing", "ID fields", "id")
})
t.Run("with composite keys", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.Update(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &struct {
ID int `ksql:"id"`
// Missing UserID
PermID int `ksql:"perm_id"`
}{
ID: 1,
PermID: 42,
})
tt.AssertErrContains(t, err, "missing", "ID fields", "user_id")
})
})
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx, cancel := context.WithCancel(ctx)
cancel()
c := newTestDB(db, dialect)
err = c.Update(ctx, usersTable, user{
ID: 1,
Name: "Thayane",
})
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
})
})
}
// QueryChunksTest runs all tests for making sure the QueryChunks function is
// working for a given adapter and dialect.
func QueryChunksTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("QueryChunks", func(t *testing.T) {
variations := []struct {
desc string
queryPrefix string
}{
{
desc: "with select *",
queryPrefix: "SELECT * ",
},
{
desc: "building the SELECT part of the query internally",
queryPrefix: "",
},
}
for _, variation := range variations {
t.Run(variation.desc, func(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{
Name: "User1",
Address: address{Country: "BR"},
})
var length int
var u user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `FROM users WHERE name = ` + c.dialect.Placeholder(0),
Params: []interface{}{"User1"},
ChunkSize: 100,
ForEachChunk: func(users []user) error {
length = len(users)
if length > 0 {
u = users[0]
}
return nil
},
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, length, 1)
tt.AssertNotEqual(t, u.ID, uint(0))
tt.AssertEqual(t, u.Name, "User1")
tt.AssertEqual(t, u.Address.Country, "BR")
})
t.Run("should query one chunk correctly", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Address: address{Country: "US"}})
_ = c.Insert(ctx, usersTable, &user{Name: "User2", Address: address{Country: "BR"}})
var lengths []int
var users []user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
ForEachChunk: func(buffer []user) error {
users = append(users, buffer...)
lengths = append(lengths, len(buffer))
return nil
},
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(lengths), 1)
tt.AssertEqual(t, lengths[0], 2)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertEqual(t, users[0].Address.Country, "US")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "User2")
tt.AssertEqual(t, users[1].Address.Country, "BR")
})
t.Run("should query chunks of 1 correctly", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Address: address{Country: "US"}})
_ = c.Insert(ctx, usersTable, &user{Name: "User2", Address: address{Country: "BR"}})
var lengths []int
var users []user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 1,
ForEachChunk: func(buffer []user) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return nil
},
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 2)
tt.AssertEqual(t, lengths, []int{1, 1})
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertEqual(t, users[0].Address.Country, "US")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "User2")
tt.AssertEqual(t, users[1].Address.Country, "BR")
})
t.Run("should load partially filled chunks correctly", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
var lengths []int
var users []user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
ForEachChunk: func(buffer []user) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return nil
},
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 3)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "User2")
tt.AssertNotEqual(t, users[2].ID, uint(0))
tt.AssertEqual(t, users[2].Name, "User3")
tt.AssertEqual(t, lengths, []int{2, 1})
})
// xxx
t.Run("should query joined tables correctly", func(t *testing.T) {
// This test only makes sense with no query prefix
if variation.queryPrefix != "" {
return
}
db, closer := newDBAdapter(t)
defer closer.Close()
joao := user{
Name: "Thiago Ribeiro",
Age: 24,
}
thatiana := user{
Name: "Thatiana Ribeiro",
Age: 20,
}
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &joao)
_ = c.Insert(ctx, usersTable, &thatiana)
_, err := db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`))
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post2')`))
tt.AssertNoErr(t, err)
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'Thiago Post1')`))
tt.AssertNoErr(t, err)
var lengths []int
var users []user
var posts []post
err = c.QueryChunks(ctx, ChunkParser{
Query: fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
),
Params: []interface{}{"% Ribeiro"},
ChunkSize: 2,
ForEachChunk: func(chunk []struct {
User user `tablename:"u"`
Post post `tablename:"p"`
}) error {
lengths = append(lengths, len(chunk))
for _, row := range chunk {
users = append(users, row.User)
posts = append(posts, row.Post)
}
return nil
},
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(posts), 3)
tt.AssertEqual(t, users[0].ID, joao.ID)
tt.AssertEqual(t, users[0].Name, "Thiago Ribeiro")
tt.AssertEqual(t, posts[0].Title, "Thiago Post1")
tt.AssertEqual(t, users[1].ID, thatiana.ID)
tt.AssertEqual(t, users[1].Name, "Thatiana Ribeiro")
tt.AssertEqual(t, posts[1].Title, "Thatiana Post1")
tt.AssertEqual(t, users[2].ID, thatiana.ID)
tt.AssertEqual(t, users[2].Name, "Thatiana Ribeiro")
tt.AssertEqual(t, posts[2].Title, "Thatiana Post2")
})
t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
var lengths []int
var users []user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
ForEachChunk: func(buffer []user) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return ErrAbortIteration
},
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 2)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "User2")
tt.AssertEqual(t, lengths, []int{2})
})
t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
returnVals := []error{nil, ErrAbortIteration}
var lengths []int
var users []user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
ForEachChunk: func(buffer []user) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return shiftErrSlice(&returnVals)
},
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 3)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "User2")
tt.AssertNotEqual(t, users[2].ID, uint(0))
tt.AssertEqual(t, users[2].Name, "User3")
tt.AssertEqual(t, lengths, []int{2, 1})
})
t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
var lengths []int
var users []user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
ForEachChunk: func(buffer []user) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return fmt.Errorf("fake error msg")
},
})
tt.AssertNotEqual(t, err, nil)
tt.AssertEqual(t, len(users), 2)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "User2")
tt.AssertEqual(t, lengths, []int{2})
})
t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
returnVals := []error{nil, fmt.Errorf("fake error msg")}
var lengths []int
var users []user
err = c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
ForEachChunk: func(buffer []user) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return shiftErrSlice(&returnVals)
},
})
tt.AssertNotEqual(t, err, nil)
tt.AssertEqual(t, len(users), 3)
tt.AssertNotEqual(t, users[0].ID, uint(0))
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertNotEqual(t, users[1].ID, uint(0))
tt.AssertEqual(t, users[1].Name, "User2")
tt.AssertNotEqual(t, users[2].ID, uint(0))
tt.AssertEqual(t, users[2].Name, "User3")
tt.AssertEqual(t, lengths, []int{2, 1})
})
t.Run("should report error if the input function is invalid", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
funcs := []interface{}{
nil,
"not a function",
func() error {
return nil
},
func(extraInputValue []user, extra []user) error {
return nil
},
func(invalidArgType string) error {
return nil
},
func(missingReturnType []user) {
},
func(users []user) string {
return ""
},
func(extraReturnValue []user) ([]user, error) {
return nil, nil
},
func(notSliceOfStructs []string) error {
return nil
},
}
for _, fn := range funcs {
err := c.QueryChunks(ctx, ChunkParser{
Query: variation.queryPrefix + `FROM users`,
Params: []interface{}{},
ChunkSize: 2,
ForEachChunk: fn,
})
tt.AssertNotEqual(t, err, nil)
}
})
t.Run("should report error if the query is not valid", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.QueryChunks(ctx, ChunkParser{
Query: `SELECT * FROM not a valid query`,
Params: []interface{}{},
ChunkSize: 2,
ForEachChunk: func(buffer []user) error {
return nil
},
})
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
err := c.QueryChunks(ctx, ChunkParser{
Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`,
Params: []interface{}{},
ChunkSize: 2,
ForEachChunk: func(buffer []struct {
User user `tablename:"users"`
Post post `tablename:"posts"`
}) error {
return nil
},
})
tt.AssertErrContains(t, err, "nested struct", "feature")
})
})
}
t.Run("error cases", func(t *testing.T) {
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx, cancel := context.WithCancel(ctx)
cancel()
c := newTestDB(db, dialect)
err := c.QueryChunks(ctx, ChunkParser{
Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`,
Params: []interface{}{},
ChunkSize: 2,
ForEachChunk: func(rows []user) error {
return nil
},
})
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
})
})
})
}
// TransactionTest runs all tests for making sure the Transaction function is
// working for a given adapter and dialect.
func TransactionTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("Transaction", func(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
var users []user
err = c.Transaction(ctx, func(db Provider) error {
db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
return nil
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(users), 2)
tt.AssertEqual(t, users[0].Name, "User1")
tt.AssertEqual(t, users[1].Name, "User2")
})
t.Run("should work normally in nested transactions", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u := user{
Name: "User1",
}
err = c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var updatedUser user
err = c.Transaction(ctx, func(db Provider) error {
u.Age = 42
err = db.Patch(ctx, usersTable, &u)
if err != nil {
return err
}
return db.Transaction(ctx, func(db Provider) error {
return db.QueryOne(ctx, &updatedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
})
})
tt.AssertNoErr(t, err)
tt.AssertEqual(t, updatedUser.ID, u.ID)
tt.AssertEqual(t, updatedUser.Age, 42)
})
t.Run("should rollback when there are errors", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u1 := user{Name: "User1", Age: 42}
u2 := user{Name: "User2", Age: 42}
_ = c.Insert(ctx, usersTable, &u1)
_ = c.Insert(ctx, usersTable, &u2)
err = c.Transaction(ctx, func(db Provider) error {
err = db.Insert(ctx, usersTable, &user{Name: "User3"})
tt.AssertNoErr(t, err)
err = db.Insert(ctx, usersTable, &user{Name: "User4"})
tt.AssertNoErr(t, err)
_, err = db.Exec(ctx, "UPDATE users SET age = 22")
tt.AssertNoErr(t, err)
return fmt.Errorf("fake-error")
})
tt.AssertNotEqual(t, err, nil)
tt.AssertEqual(t, err.Error(), "fake-error")
var users []user
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, users, []user{u1, u2})
})
t.Run("should rollback when the fn call panics", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u1 := user{Name: "User1", Age: 42}
u2 := user{Name: "User2", Age: 42}
_ = c.Insert(ctx, usersTable, &u1)
_ = c.Insert(ctx, usersTable, &u2)
panicPayload := tt.PanicHandler(func() {
c.Transaction(ctx, func(db Provider) error {
err = db.Insert(ctx, usersTable, &user{Name: "User3"})
tt.AssertNoErr(t, err)
err = db.Insert(ctx, usersTable, &user{Name: "User4"})
tt.AssertNoErr(t, err)
_, err = db.Exec(ctx, "UPDATE users SET age = 22")
tt.AssertNoErr(t, err)
panic("fakePanicPayload")
})
})
tt.AssertEqual(t, panicPayload, "fakePanicPayload")
var users []user
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
tt.AssertNoErr(t, err)
tt.AssertEqual(t, users, []user{u1, u2})
})
t.Run("should handle rollback errors when the fn call panics", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
cMock := mockTxBeginner{
DBAdapter: c.db,
BeginTxFn: func(ctx context.Context) (Tx, error) {
return mockTx{
DBAdapter: c.db,
RollbackFn: func(ctx context.Context) error {
return fmt.Errorf("fakeRollbackErrMsg")
},
}, nil
},
}
c.db = cMock
panicPayload := tt.PanicHandler(func() {
c.Transaction(ctx, func(db Provider) error {
panic("fakePanicPayload")
})
})
err, ok := panicPayload.(error)
tt.AssertEqual(t, ok, true)
tt.AssertErrContains(t, err, "fakePanicPayload", "fakeRollbackErrMsg")
})
t.Run("should handle rollback errors when fn returns an error", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
cMock := mockTxBeginner{
DBAdapter: c.db,
BeginTxFn: func(ctx context.Context) (Tx, error) {
return mockTx{
DBAdapter: c.db,
RollbackFn: func(ctx context.Context) error {
return fmt.Errorf("fakeRollbackErrMsg")
},
}, nil
},
}
c.db = cMock
err = c.Transaction(ctx, func(db Provider) error {
return fmt.Errorf("fakeTransactionErrMsg")
})
tt.AssertErrContains(t, err, "fakeTransactionErrMsg", "fakeRollbackErrMsg")
})
t.Run("should report error when BeginTx() fails", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
cMock := mockTxBeginner{
DBAdapter: c.db,
BeginTxFn: func(ctx context.Context) (Tx, error) {
return nil, fmt.Errorf("fakeErrMsg")
},
}
c.db = cMock
err := c.Transaction(ctx, func(db Provider) error {
return nil
})
tt.AssertErrContains(t, err, "KSQL", "fakeErrMsg")
})
t.Run("should report error if DBAdapter can't create transactions", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
c.db = mockDBAdapter{}
err = c.Transaction(ctx, func(db Provider) error {
return nil
})
tt.AssertErrContains(t, err, "KSQL", "can't start transaction", "DBAdapter", "TxBeginner")
})
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
ctx, cancel := context.WithCancel(ctx)
cancel()
c := newTestDB(db, dialect)
err := c.Transaction(ctx, func(db Provider) error {
return nil
})
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
})
})
}
func ModifiersTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("Modifiers", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("timeNowUTC modifier", func(t *testing.T) {
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
now := time.Now()
tt.AssertApproxTime(t,
2*time.Second, untaggedUser.UpdatedAt, now,
"updatedAt should be set to %v, but got: %v", now, untaggedUser.UpdatedAt,
)
})
t.Run("should be set to time.Now().UTC() on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at"`
}
untaggedUser := userWithNoTags{
Name: "Laura Ribeiro",
// Any time different from now:
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laurinha Ribeiro",
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser2.ID, 0)
now := time.Now()
tt.AssertApproxTime(t,
2*time.Second, untaggedUser2.UpdatedAt, now,
"updatedAt should be set to %v, but got: %v", now, untaggedUser2.UpdatedAt,
)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
// Any time different from now:
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
tt.AssertEqual(t, taggedUser.UpdatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
})
})
t.Run("timeNowUTC/skipUpdates modifier", func(t *testing.T) {
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
now := time.Now()
tt.AssertApproxTime(t,
2*time.Second, untaggedUser.CreatedAt, now,
"updatedAt should be set to %v, but got: %v", now, untaggedUser.CreatedAt,
)
})
t.Run("should be ignored on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at"`
}
untaggedUser := userWithNoTags{
Name: "Laura Ribeiro",
// Any time different from now:
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laurinha Ribeiro",
// Some random time that should be ignored:
CreatedAt: tt.ParseTime(t, "1999-08-05T14:00:00Z"),
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
// Any time different from now:
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
tt.AssertEqual(t, taggedUser.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
})
})
t.Run("skipInserts modifier", func(t *testing.T) {
t.Run("should ignore the field during insertions", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
Age int `ksql:"age"`
}
u := tsUser{
Name: "Letícia",
Age: 22,
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name *string `ksql:"name"`
Age int `ksql:"age"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
tt.AssertEqual(t, untaggedUser.Age, 22)
})
t.Run("should have no effect on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
Age int `ksql:"age"`
}
untaggedUser := userWithNoTags{
Name: "Laurinha Ribeiro",
Age: 11,
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
Age int `ksql:"age"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laura Ribeiro",
Age: 12,
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro")
tt.AssertEqual(t, untaggedUser2.Age, 12)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipInserts"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
})
})
t.Run("skipUpdates modifier", func(t *testing.T) {
t.Run("should set the field on insertion", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type tsUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipUpdates"`
}
u := tsUser{
Name: "Letícia",
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, 0)
var untaggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
}
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser.Name, "Letícia")
})
t.Run("should be ignored on updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
Age int `ksql:"age"`
}
untaggedUser := userWithNoTags{
Name: "Laurinha Ribeiro",
Age: 11,
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipUpdates"`
Age int `ksql:"age"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laura Ribeiro",
Age: 12,
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.Name, "Laurinha Ribeiro")
tt.AssertEqual(t, untaggedUser2.Age, 12)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
}
untaggedUser := userWithNoTags{
Name: "Marta Ribeiro",
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name,skipUpdates"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
})
})
t.Run("nullable modifier", func(t *testing.T) {
t.Run("should prevent null fields from being ignored during insertions", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
// The default value of the column "nullable_field"
// is the string: "not_null".
//
// So the tagged struct below should insert passing NULL
// and the untagged should insert not passing any value
// for this column, thus, only the second one should create
// a recording using the default value.
var taggedUser struct {
ID uint `ksql:"id"`
NullableField *string `ksql:"nullable_field,nullable"`
}
var untaggedUser struct {
ID uint `ksql:"id"`
NullableField *string `ksql:"nullable_field"`
}
err := c.Insert(ctx, usersTable, &taggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, taggedUser.ID, 0)
err = c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
err = c.QueryOne(ctx, &taggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), taggedUser.ID)
tt.AssertNoErr(t, err)
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.NullableField == nil, true)
tt.AssertEqual(t, untaggedUser.NullableField, nullable.String("not_null"))
})
t.Run("should prevent null fields from being ignored during updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
NullableField *string `ksql:"nullable_field"`
}
untaggedUser := userWithNoTags{
Name: "Laurinha Ribeiro",
NullableField: nullable.String("fakeValue"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
NullableField *string `ksql:"nullable_field,nullable"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laura Ribeiro",
NullableField: nil,
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro")
tt.AssertEqual(t, untaggedUser2.NullableField == nil, true)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name *string `ksql:"name"`
}
untaggedUser := userWithNoTags{
Name: nullable.String("Marta Ribeiro"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name *string `ksql:"name,nullable"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, nullable.String("Marta Ribeiro"))
})
t.Run("should cause no effect if used on a non pointer field", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
type user struct {
ID uint `ksql:"id"`
Name string `ksql:"name,nullable"`
Age int `ksql:"age,nullable"`
}
u1 := user{
Name: "Marta Ribeiro",
}
err := c.Insert(ctx, usersTable, &u1)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u1.ID, 0)
err = c.Patch(ctx, usersTable, &struct {
ID uint `ksql:"id"`
Age int `ksql:"age,nullable"`
}{
ID: u1.ID,
Age: 42,
})
tt.AssertNoErr(t, err)
var u2 user
err = c.QueryOne(ctx, &u2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u1.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u2.ID, u1.ID)
tt.AssertEqual(t, u2.Name, "Marta Ribeiro")
tt.AssertEqual(t, u2.Age, 42)
})
})
})
}
// ScanRowsTest runs all tests for making sure the ScanRows feature is
// working for a given adapter and dialect.
func ScanRowsTest(
t *testing.T,
dialect sqldialect.Provider,
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("ScanRows", func(t *testing.T) {
t.Run("should scan users correctly", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Age: 22})
_ = c.Insert(ctx, usersTable, &user{Name: "User2", Age: 14})
_ = c.Insert(ctx, usersTable, &user{Name: "User3", Age: 43})
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err)
defer rows.Close()
tt.AssertEqual(t, rows.Next(), true)
var u user
err = scanRows(ctx, dialect, rows, &u)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Name, "User2")
tt.AssertEqual(t, u.Age, 14)
})
t.Run("should ignore extra columns from query", func(t *testing.T) {
err := createTables(dialect, connStr)
tt.AssertNoErr(t, err)
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Age: 22})
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'")
tt.AssertNoErr(t, err)
defer rows.Close()
tt.AssertEqual(t, rows.Next(), true)
var u struct {
ID int `ksql:"id"`
Age int `ksql:"age"`
// Omitted for testing purposes:
// Name string `ksql:"name"`
}
err = scanRows(ctx, dialect, rows, &u)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Age, 22)
})
t.Run("should report scan errors", func(t *testing.T) {
type brokenUser struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct{}:
Age struct{} `ksql:"age"`
}
type brokenNestedStruct struct {
User struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct:
Age struct{} `ksql:"age"`
} `tablename:"u"`
Post post `tablename:"p"`
}
tests := []struct {
desc string
query string
scanTarget interface{}
expectErrToContain []string
}{
{
desc: "with anonymous structs",
query: "FROM users WHERE name='User22'",
scanTarget: &struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct{}:
Age struct{} `ksql:"age"`
}{},
expectErrToContain: []string{" .Age", "struct {}"},
},
{
desc: "with named structs",
query: "FROM users WHERE name='User22'",
scanTarget: &brokenUser{},
expectErrToContain: []string{"brokenUser.Age", "struct {}"},
},
{
desc: "with anonymous nested structs",
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
scanTarget: &struct {
User struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct:
Age struct{} `ksql:"age"`
} `tablename:"u"`
Post post `tablename:"p"`
}{},
expectErrToContain: []string{".User.Age", "struct {}"},
},
{
desc: "with named nested structs",
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
scanTarget: &brokenNestedStruct{},
expectErrToContain: []string{"brokenNestedStruct.User.Age", "struct {}"},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
err := createTables(dialect, connStr)
tt.AssertNoErr(t, err)
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, dialect)
u := user{Name: "User22", Age: 22}
_ = c.Insert(ctx, usersTable, &u)
_ = c.Insert(ctx, postsTable, &post{UserID: u.ID, Title: "FakeTitle"})
query := mustBuildSelectQuery(t, dialect, test.scanTarget, test.query)
rows, err := db.QueryContext(ctx, query)
tt.AssertNoErr(t, err)
defer rows.Close()
tt.AssertEqual(t, rows.Next(), true)
err = scanRows(ctx, dialect, rows, test.scanTarget)
tt.AssertErrContains(t, err, test.expectErrToContain...)
})
}
})
t.Run("should report error for closed rows", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err)
var u user
err = rows.Close()
tt.AssertNoErr(t, err)
err = scanRows(ctx, dialect, rows, &u)
tt.AssertNotEqual(t, err, nil)
})
t.Run("should report if record is not a pointer", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err)
defer rows.Close()
var u user
err = scanRows(ctx, dialect, rows, u)
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
})
t.Run("should report if record is not a pointer to struct", func(t *testing.T) {
err := createTables(dialect, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
db, closer := newDBAdapter(t)
defer closer.Close()
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err)
defer rows.Close()
var u map[string]interface{}
err = scanRows(ctx, dialect, rows, &u)
tt.AssertErrContains(t, err, "KSQL", "expected", "pointer to struct", "map[string]interface")
})
})
}
func createTables(dialect sqldialect.Provider, connStr string) error {
driver := dialect.DriverName()
if connStr == "" {
return fmt.Errorf("unsupported dialect: '%s'", driver)
}
db, err := sql.Open(driver, connStr)
if err != nil {
return err
}
defer db.Close()
db.Exec(`DROP TABLE users`)
switch driver {
case "sqlite3":
_, err = db.Exec(`CREATE TABLE users (
id INTEGER PRIMARY KEY,
age INTEGER,
name TEXT,
address BLOB,
created_at DATETIME,
updated_at DATETIME,
nullable_field TEXT DEFAULT "not_null"
)`)
case "postgres":
_, err = db.Exec(`CREATE TABLE users (
id serial PRIMARY KEY,
age INT,
name VARCHAR(50),
address jsonb,
created_at TIMESTAMP,
updated_at TIMESTAMP,
nullable_field VARCHAR(50) DEFAULT 'not_null'
)`)
case "mysql":
_, err = db.Exec(`CREATE TABLE users (
id INT AUTO_INCREMENT PRIMARY KEY,
age INT,
name VARCHAR(50),
address JSON,
created_at DATETIME,
updated_at DATETIME,
nullable_field VARCHAR(50) DEFAULT "not_null"
)`)
case "sqlserver":
_, err = db.Exec(`CREATE TABLE users (
id INT IDENTITY(1,1) PRIMARY KEY,
age INT,
name VARCHAR(50),
address NVARCHAR(4000),
created_at DATETIME,
updated_at DATETIME,
nullable_field VARCHAR(50) DEFAULT 'not_null'
)`)
}
if err != nil {
return fmt.Errorf("failed to create new users table: %s", err.Error())
}
db.Exec(`DROP TABLE posts`)
switch driver {
case "sqlite3":
_, err = db.Exec(`CREATE TABLE posts (
id INTEGER PRIMARY KEY,
user_id INTEGER,
title TEXT
)`)
case "postgres":
_, err = db.Exec(`CREATE TABLE posts (
id serial PRIMARY KEY,
user_id INT,
title VARCHAR(50)
)`)
case "mysql":
_, err = db.Exec(`CREATE TABLE posts (
id INT AUTO_INCREMENT PRIMARY KEY,
user_id INT,
title VARCHAR(50)
)`)
case "sqlserver":
_, err = db.Exec(`CREATE TABLE posts (
id INT IDENTITY(1,1) PRIMARY KEY,
user_id INT,
title VARCHAR(50)
)`)
}
if err != nil {
return fmt.Errorf("failed to create new posts table: %s", err.Error())
}
db.Exec(`DROP TABLE user_permissions`)
switch driver {
case "sqlite3":
_, err = db.Exec(`CREATE TABLE user_permissions (
id INTEGER PRIMARY KEY,
user_id INTEGER,
perm_id INTEGER,
type TEXT,
UNIQUE (user_id, perm_id)
)`)
case "postgres":
_, err = db.Exec(`CREATE TABLE user_permissions (
id serial PRIMARY KEY,
user_id INT,
perm_id INT,
type VARCHAR(50),
UNIQUE (user_id, perm_id)
)`)
case "mysql":
_, err = db.Exec(`CREATE TABLE user_permissions (
id INT AUTO_INCREMENT PRIMARY KEY,
user_id INT,
perm_id INT,
type VARCHAR(50),
UNIQUE KEY (user_id, perm_id)
)`)
case "sqlserver":
_, err = db.Exec(`CREATE TABLE user_permissions (
id INT IDENTITY(1,1) PRIMARY KEY,
user_id INT,
perm_id INT,
type VARCHAR(50),
CONSTRAINT unique_1 UNIQUE (user_id, perm_id)
)`)
}
if err != nil {
return fmt.Errorf("failed to create new user_permissions table: %s", err.Error())
}
return nil
}
func newTestDB(db DBAdapter, dialect sqldialect.Provider) DB {
return DB{
dialect: dialect,
db: db,
}
}
func shiftErrSlice(errs *[]error) error {
err := (*errs)[0]
*errs = (*errs)[1:]
return err
}
func getUserByID(db DBAdapter, dialect sqldialect.Provider, result *user, id uint) error {
rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id)
if err != nil {
return err
}
defer rows.Close()
if rows.Next() == false {
if rows.Err() != nil {
return rows.Err()
}
return sql.ErrNoRows
}
modifier, _ := modifiers.LoadGlobalModifier("json")
value := modifiers.AttrScanWrapper{
Ctx: context.TODO(),
AttrPtr: &result.Address,
ScanFn: modifier.Scan,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers
Method: "Query",
},
}
err = rows.Scan(&result.ID, &result.Name, &result.Age, &value)
if err != nil {
return err
}
return nil
}
func getUserByName(db DBAdapter, dialect sqldialect.Provider, result *user, name string) error {
rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name)
if err != nil {
return err
}
defer rows.Close()
if rows.Next() == false {
if rows.Err() != nil {
return rows.Err()
}
return sql.ErrNoRows
}
var rawAddr []byte
err = rows.Scan(&result.ID, &result.Name, &result.Age, &rawAddr)
if err != nil {
return err
}
if rawAddr == nil {
return nil
}
return json.Unmarshal(rawAddr, &result.Address)
}
func createUserPermission(db DBAdapter, dialect sqldialect.Provider, userPerm userPermission) error {
_, err := db.ExecContext(context.TODO(),
`INSERT INTO user_permissions (user_id, perm_id, type)
VALUES (`+dialect.Placeholder(0)+`, `+dialect.Placeholder(1)+`, `+dialect.Placeholder(2)+`)`,
userPerm.UserID, userPerm.PermID, userPerm.Type,
)
return err
}
func getUserPermissionBySecondaryKeys(db DBAdapter, dialect sqldialect.Provider, userID int, permID int) (userPermission, error) {
rows, err := db.QueryContext(context.TODO(),
`SELECT id, user_id, perm_id, type FROM user_permissions WHERE user_id=`+dialect.Placeholder(0)+` AND perm_id=`+dialect.Placeholder(1),
userID, permID,
)
if err != nil {
return userPermission{}, err
}
defer rows.Close()
if rows.Next() == false {
if rows.Err() != nil {
return userPermission{}, rows.Err()
}
return userPermission{}, sql.ErrNoRows
}
var result userPermission
err = rows.Scan(&result.ID, &result.UserID, &result.PermID, &result.Type)
if err != nil {
return userPermission{}, err
}
return result, nil
}
func getUserPermissionsByUser(db DBAdapter, dialect sqldialect.Provider, userID int) (results []userPermission, _ error) {
rows, err := db.QueryContext(context.TODO(),
`SELECT id, user_id, perm_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
userID,
)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var userPerm userPermission
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PermID)
if err != nil {
return nil, err
}
results = append(results, userPerm)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return results, nil
}
func mustBuildSelectQuery(t *testing.T,
dialect sqldialect.Provider,
record interface{},
query string,
) string {
if strings.HasPrefix(query, "SELECT") {
return query
}
structType := reflect.TypeOf(record).Elem()
structInfo, err := structs.GetTagInfo(structType)
tt.AssertNoErr(t, err)
selectPrefix, err := buildSelectQuery(dialect, structType, structInfo, selectQueryCache[dialect.DriverName()])
tt.AssertNoErr(t, err)
return selectPrefix + query
}