mirror of https://github.com/VinGarcia/ksql.git
3766 lines
108 KiB
Go
3766 lines
108 KiB
Go
package ksql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/vingarcia/ksql/internal/modifiers"
|
|
"github.com/vingarcia/ksql/internal/structs"
|
|
tt "github.com/vingarcia/ksql/internal/testtools"
|
|
"github.com/vingarcia/ksql/ksqlmodifiers"
|
|
"github.com/vingarcia/ksql/nullable"
|
|
"github.com/vingarcia/ksql/sqldialect"
|
|
)
|
|
|
|
// usersTable is the table handle passed to the tests that write to the
// "users" table. It is built with only the table name, i.e. no ID
// columns are listed explicitly.
var usersTable = NewTable("users")

// user is the struct the tests in this file use to read and write rows
// of the "users" table.
type user struct {
	ID   uint   `ksql:"id"`
	Name string `ksql:"name"`
	Age  int    `ksql:"age"`

	// Address uses the "json" modifier in its ksql tag: the tests insert
	// JSON text into this column and expect it back as an address value.
	Address address `ksql:"address,json"`

	// This attr has no ksql tag, thus, it should be ignored:
	AttrThatShouldBeIgnored string
}
|
|
|
|
// address is the type of the user.Address attribute, which is tagged
// with the "json" modifier. Its attributes are named through json tags.
type address struct {
	Street string `json:"street"`
	Number string `json:"number"`

	City    string `json:"city"`
	State   string `json:"state"`
	Country string `json:"country"`
}
|
|
|
|
// postsTable is the table handle for the "posts" table. It is built with
// only the table name, i.e. no ID columns are listed explicitly.
var postsTable = NewTable("posts")

// post is the struct the tests in this file use to read rows of the
// "posts" table, e.g. when joining it with the "users" table.
type post struct {
	ID     int    `ksql:"id"`
	UserID uint   `ksql:"user_id"`
	Title  string `ksql:"title"`
}
|
|
|
|
// userPermissionsTable is the table handle for the "user_permissions"
// table. The extra arguments ("user_id" and "perm_id") are the column
// names given to NewTable as the IDs of this table.
var userPermissionsTable = NewTable("user_permissions", "user_id", "perm_id")

// userPermission is the struct the tests in this file use to read and
// write rows of the "user_permissions" table.
type userPermission struct {
	ID     int    `ksql:"id"`
	UserID int    `ksql:"user_id"`
	PermID int    `ksql:"perm_id"`
	Type   string `ksql:"type"`
}
|
|
|
|
// RunTestsForAdapter will run all necessary tests for making sure
|
|
// a given adapter is working as expected.
|
|
//
|
|
// Optionally it is also possible to run each of these tests
|
|
// separatedly, which might be useful during the development
|
|
// of a new adapter.
|
|
func RunTestsForAdapter(
|
|
t *testing.T,
|
|
adapterName string,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
t.Run(adapterName, func(t *testing.T) {
|
|
t.Run(dialect.DriverName(), func(t *testing.T) {
|
|
QueryTest(t, dialect, connStr, newDBAdapter)
|
|
QueryOneTest(t, dialect, connStr, newDBAdapter)
|
|
InsertTest(t, dialect, connStr, newDBAdapter)
|
|
DeleteTest(t, dialect, connStr, newDBAdapter)
|
|
PatchTest(t, dialect, connStr, newDBAdapter)
|
|
QueryChunksTest(t, dialect, connStr, newDBAdapter)
|
|
TransactionTest(t, dialect, connStr, newDBAdapter)
|
|
ModifiersTest(t, dialect, connStr, newDBAdapter)
|
|
ScanRowsTest(t, dialect, connStr, newDBAdapter)
|
|
})
|
|
})
|
|
}
|
|
|
|
// QueryTest runs all tests for making sure the Query function is
|
|
// working for a given adapter and dialect.
|
|
func QueryTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("Query", func(t *testing.T) {
|
|
variations := []struct {
|
|
desc string
|
|
queryPrefix string
|
|
}{
|
|
{
|
|
desc: "with select *",
|
|
queryPrefix: "SELECT * ",
|
|
},
|
|
{
|
|
desc: "building the SELECT part of the query internally",
|
|
queryPrefix: "",
|
|
},
|
|
}
|
|
for _, variation := range variations {
|
|
t.Run(variation.desc, func(t *testing.T) {
|
|
t.Run("using slice of structs", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should return 0 results correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var users []user
|
|
err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 0)
|
|
|
|
users = []user{}
|
|
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 0)
|
|
})
|
|
|
|
t.Run("should return a user correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
var users []user
|
|
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 1)
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "Bia")
|
|
tt.AssertEqual(t, users[0].Address.Country, "BR")
|
|
})
|
|
|
|
t.Run("should return multiple users correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
var users []user
|
|
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 2)
|
|
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "João Garcia")
|
|
tt.AssertEqual(t, users[0].Address.Country, "US")
|
|
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "Bia Garcia")
|
|
tt.AssertEqual(t, users[1].Address.Country, "BR")
|
|
})
|
|
|
|
t.Run("should query joined tables correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
// This test only makes sense with no query prefix
|
|
if variation.queryPrefix != "" {
|
|
return
|
|
}
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
var joao user
|
|
getUserByName(db, dialect, &joao, "João Ribeiro")
|
|
tt.AssertNoErr(t, err)
|
|
|
|
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
var bia user
|
|
getUserByName(db, dialect, &bia, "Bia Ribeiro")
|
|
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`))
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`))
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`))
|
|
tt.AssertNoErr(t, err)
|
|
|
|
ctx := context.Background()
|
|
c := newTestDB(db, dialect)
|
|
var rows []struct {
|
|
User user `tablename:"u"`
|
|
Post post `tablename:"p"`
|
|
|
|
// This one has no ksql or tablename tag,
|
|
// so it should just be ignored to avoid strange
|
|
// unexpected errors:
|
|
ExtraStructThatShouldBeIgnored user
|
|
}
|
|
err = c.Query(ctx, &rows, fmt.Sprint(
|
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
|
` ORDER BY u.id, p.id`,
|
|
), "% Ribeiro")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(rows), 3)
|
|
|
|
tt.AssertEqual(t, rows[0].User.ID, joao.ID)
|
|
tt.AssertEqual(t, rows[0].User.Name, "João Ribeiro")
|
|
tt.AssertEqual(t, rows[0].Post.Title, "João Post1")
|
|
tt.AssertNotEqual(t, rows[0].Post.ID, 0)
|
|
|
|
tt.AssertEqual(t, rows[1].User.ID, bia.ID)
|
|
tt.AssertEqual(t, rows[1].User.Name, "Bia Ribeiro")
|
|
tt.AssertEqual(t, rows[1].Post.Title, "Bia Post1")
|
|
tt.AssertNotEqual(t, rows[1].Post.ID, 0)
|
|
|
|
tt.AssertEqual(t, rows[2].User.ID, bia.ID)
|
|
tt.AssertEqual(t, rows[2].User.Name, "Bia Ribeiro")
|
|
tt.AssertEqual(t, rows[2].Post.Title, "Bia Post2")
|
|
tt.AssertNotEqual(t, rows[2].Post.ID, 0)
|
|
})
|
|
})
|
|
|
|
t.Run("using slice of pointers to structs", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should return 0 results correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var users []*user
|
|
err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 0)
|
|
|
|
users = []*user{}
|
|
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 0)
|
|
})
|
|
|
|
t.Run("should return a user correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
var users []*user
|
|
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 1)
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "Bia")
|
|
tt.AssertEqual(t, users[0].Address.Country, "BR")
|
|
})
|
|
|
|
t.Run("should return multiple users correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
var users []*user
|
|
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 2)
|
|
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "João Garcia")
|
|
tt.AssertEqual(t, users[0].Address.Country, "US")
|
|
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "Bia Garcia")
|
|
tt.AssertEqual(t, users[1].Address.Country, "BR")
|
|
})
|
|
|
|
t.Run("should query joined tables correctly", func(t *testing.T) {
|
|
// This test only makes sense with no query prefix
|
|
if variation.queryPrefix != "" {
|
|
return
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
var joao user
|
|
getUserByName(db, dialect, &joao, "João Ribeiro")
|
|
|
|
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
var bia user
|
|
getUserByName(db, dialect, &bia, "Bia Ribeiro")
|
|
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`))
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`))
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`))
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
var rows []*struct {
|
|
User user `tablename:"u"`
|
|
Post post `tablename:"p"`
|
|
}
|
|
err = c.Query(ctx, &rows, fmt.Sprint(
|
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
|
` ORDER BY u.id, p.id`,
|
|
), "% Ribeiro")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(rows), 3)
|
|
|
|
tt.AssertEqual(t, rows[0].User.ID, joao.ID)
|
|
tt.AssertEqual(t, rows[0].User.Name, "João Ribeiro")
|
|
tt.AssertEqual(t, rows[0].Post.Title, "João Post1")
|
|
|
|
tt.AssertEqual(t, rows[1].User.ID, bia.ID)
|
|
tt.AssertEqual(t, rows[1].User.Name, "Bia Ribeiro")
|
|
tt.AssertEqual(t, rows[1].Post.Title, "Bia Post1")
|
|
|
|
tt.AssertEqual(t, rows[2].User.ID, bia.ID)
|
|
tt.AssertEqual(t, rows[2].User.Name, "Bia Ribeiro")
|
|
tt.AssertEqual(t, rows[2].Post.Title, "Bia Post2")
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
t.Run("testing error cases", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should report error if input is not a pointer to a slice of structs", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Andréa Sá', 0)`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Caio Sá', 0)`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
err = c.Query(ctx, &user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
|
|
tt.AssertErrContains(t, err, "expected", "to be a slice", "user")
|
|
|
|
err = c.Query(ctx, []*user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
|
|
tt.AssertErrContains(t, err, "expected", "slice of structs", "user")
|
|
|
|
var i int
|
|
err = c.Query(ctx, &i, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
|
|
tt.AssertErrContains(t, err, "expected", "to be a slice", "int")
|
|
|
|
err = c.Query(ctx, &[]int{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
|
|
tt.AssertErrContains(t, err, "expected", "slice of structs", "[]int")
|
|
})
|
|
|
|
t.Run("should report error if the query is not valid", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var users []user
|
|
err := c.Query(ctx, &users, `SELECT * FROM not a valid query`)
|
|
tt.AssertErrContains(t, err, "error running query")
|
|
})
|
|
|
|
t.Run("should report error if the TagInfoCache returns an error", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
// Provoque an error by sending an invalid struct:
|
|
var users []struct {
|
|
ID int `ksql:"id"`
|
|
// Private names cannot have ksql tags:
|
|
badPrivateField string `ksql:"name"`
|
|
}
|
|
err := c.Query(ctx, &users, `SELECT * FROM users`)
|
|
tt.AssertErrContains(t, err, "badPrivateField")
|
|
})
|
|
|
|
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var rows []struct {
|
|
User user `tablename:"users"`
|
|
Post post `tablename:"posts"`
|
|
}
|
|
err := c.Query(ctx, &rows, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`)
|
|
tt.AssertErrContains(t, err, "nested struct", "feature")
|
|
})
|
|
|
|
t.Run("should report error for nested structs with invalid types", func(t *testing.T) {
|
|
t.Run("int", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var rows []struct {
|
|
Foo int `tablename:"foo"`
|
|
}
|
|
err := c.Query(ctx, &rows, fmt.Sprint(
|
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
|
` ORDER BY u.id, p.id`,
|
|
), "% Ribeiro")
|
|
|
|
tt.AssertErrContains(t, err, "foo", "int")
|
|
})
|
|
|
|
t.Run("*struct", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var rows []struct {
|
|
Foo *user `tablename:"foo"`
|
|
}
|
|
err := c.Query(ctx, &rows, fmt.Sprint(
|
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
|
` ORDER BY u.id, p.id`,
|
|
), "% Ribeiro")
|
|
|
|
tt.AssertErrContains(t, err, "foo", "*ksql.user")
|
|
})
|
|
})
|
|
|
|
t.Run("should report error if nested struct is invalid", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var rows []struct {
|
|
User user `tablename:"users"`
|
|
Post struct {
|
|
Attr1 int `ksql:"invalid_repeated_name"`
|
|
Attr2 int `ksql:"invalid_repeated_name"`
|
|
} `tablename:"posts"`
|
|
}
|
|
err := c.Query(ctx, &rows, `FROM users u JOIN posts p ON u.id = p.user_id`)
|
|
tt.AssertErrContains(t, err, "same ksql tag name", "invalid_repeated_name")
|
|
})
|
|
|
|
t.Run("should report error if DBAdapter.Scan() returns an error", func(t *testing.T) {
|
|
db := mockDBAdapter{
|
|
QueryContextFn: func(ctx context.Context, query string, args ...interface{}) (Rows, error) {
|
|
return mockRows{
|
|
ColumnsFn: func() ([]string, error) {
|
|
return []string{"id", "name", "age", "address"}, nil
|
|
},
|
|
NextFn: func() bool { return true },
|
|
ScanFn: func(values ...interface{}) error {
|
|
return fmt.Errorf("fakeScanErr")
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var users []user
|
|
err := c.Query(ctx, &users, `SELECT * FROM users`)
|
|
tt.AssertErrContains(t, err, "KSQL", "scan error", "fakeScanErr")
|
|
})
|
|
|
|
t.Run("should report error if DBAdapter.Err() returns an error", func(t *testing.T) {
|
|
db := mockDBAdapter{
|
|
QueryContextFn: func(ctx context.Context, query string, args ...interface{}) (Rows, error) {
|
|
return mockRows{
|
|
ColumnsFn: func() ([]string, error) {
|
|
return []string{"id", "name", "age", "address"}, nil
|
|
},
|
|
NextFn: func() bool { return false },
|
|
ScanFn: func(values ...interface{}) error {
|
|
return nil
|
|
},
|
|
ErrFn: func() error {
|
|
return fmt.Errorf("fakeErrMsg")
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var users []user
|
|
err := c.Query(ctx, &users, `SELECT * FROM users`)
|
|
tt.AssertErrContains(t, err, "KSQL", "fakeErrMsg")
|
|
})
|
|
|
|
t.Run("should report error if DBAdapter.Close() returns an error", func(t *testing.T) {
|
|
db := mockDBAdapter{
|
|
QueryContextFn: func(ctx context.Context, query string, args ...interface{}) (Rows, error) {
|
|
return mockRows{
|
|
ColumnsFn: func() ([]string, error) {
|
|
return []string{"id", "name", "age", "address"}, nil
|
|
},
|
|
NextFn: func() bool { return false },
|
|
ScanFn: func(values ...interface{}) error {
|
|
return nil
|
|
},
|
|
CloseFn: func() error {
|
|
return fmt.Errorf("fakeCloseErr")
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var users []user
|
|
err := c.Query(ctx, &users, `SELECT * FROM users`)
|
|
tt.AssertErrContains(t, err, "KSQL", "fakeCloseErr")
|
|
})
|
|
|
|
t.Run("should report error context.Canceled if the context is canceled", func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var users []user
|
|
err := c.Query(ctx, &users, `SELECT * FROM users`)
|
|
tt.AssertErrContains(t, err, "context", "canceled")
|
|
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
// QueryOneTest runs all tests for making sure the QueryOne function is
|
|
// working for a given adapter and dialect.
|
|
func QueryOneTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("QueryOne", func(t *testing.T) {
|
|
variations := []struct {
|
|
desc string
|
|
queryPrefix string
|
|
}{
|
|
{
|
|
desc: "with select *",
|
|
queryPrefix: "SELECT * ",
|
|
},
|
|
{
|
|
desc: "building the SELECT part of the query internally",
|
|
queryPrefix: "",
|
|
},
|
|
}
|
|
for _, variation := range variations {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run(variation.desc, func(t *testing.T) {
|
|
t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
u := user{}
|
|
err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`)
|
|
tt.AssertEqual(t, err, ErrRecordNotFound)
|
|
})
|
|
|
|
t.Run("should return a user correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
u := user{}
|
|
err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, uint(0))
|
|
tt.AssertEqual(t, u.Name, "Bia")
|
|
tt.AssertEqual(t, u.Address, address{
|
|
Country: "BR",
|
|
})
|
|
})
|
|
|
|
t.Run("should return only the first result on multiples matches", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var u user
|
|
err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá")
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, u.Name, "Andréa Sá")
|
|
tt.AssertEqual(t, u.Age, 0)
|
|
tt.AssertEqual(t, u.Address, address{
|
|
Country: "US",
|
|
})
|
|
})
|
|
|
|
t.Run("should query joined tables correctly", func(t *testing.T) {
|
|
// This test only makes sense with no query prefix
|
|
if variation.queryPrefix != "" {
|
|
return
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
var joao user
|
|
getUserByName(db, dialect, &joao, "João Ribeiro")
|
|
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`))
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
var row struct {
|
|
User user `tablename:"u"`
|
|
Post post `tablename:"p"`
|
|
}
|
|
err = c.QueryOne(ctx, &row, fmt.Sprint(
|
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
|
` ORDER BY u.id, p.id`,
|
|
), "% Ribeiro")
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, row.User.ID, joao.ID)
|
|
tt.AssertEqual(t, row.User.Name, "João Ribeiro")
|
|
tt.AssertEqual(t, row.Post.Title, "João Post1")
|
|
})
|
|
|
|
t.Run("should handle column tags as case-insensitive as SQL does", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Count Olivia', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var row struct {
|
|
Count int `ksql:"myCount"`
|
|
}
|
|
err = c.QueryOne(ctx, &row, `SELECT count(*) as myCount FROM users WHERE name='Count Olivia'`)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, row.Count, 1)
|
|
})
|
|
})
|
|
}
|
|
|
|
t.Run("should report error if input is not a pointer to struct", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
_, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.QueryOne(ctx, &[]user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
|
|
tt.AssertErrContains(t, err, "pointer to struct")
|
|
|
|
err = c.QueryOne(ctx, user{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
|
|
tt.AssertErrContains(t, err, "pointer to struct")
|
|
})
|
|
|
|
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var u *user
|
|
err := c.QueryOne(ctx, u, `SELECT * FROM users`)
|
|
tt.AssertErrContains(t, err, "expected a valid pointer", "received a nil pointer")
|
|
})
|
|
|
|
t.Run("should report error if the query is not valid", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var u user
|
|
err := c.QueryOne(ctx, &u, `SELECT * FROM not a valid query`)
|
|
tt.AssertErrContains(t, err, "error running query")
|
|
})
|
|
|
|
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
var row struct {
|
|
User user `tablename:"users"`
|
|
Post post `tablename:"posts"`
|
|
}
|
|
err := c.QueryOne(ctx, &row, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id LIMIT 1`)
|
|
tt.AssertErrContains(t, err, "nested struct", "feature")
|
|
})
|
|
|
|
t.Run("should report error if a private field has a ksql tag", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Olivia', 0, '{"country":"US"}')`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var row struct {
|
|
count int `ksql:"my_count"`
|
|
}
|
|
err = c.QueryOne(ctx, &row, `SELECT count(*) as my_count FROM users`)
|
|
tt.AssertErrContains(t, err, "unexported", "my_count")
|
|
})
|
|
|
|
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var u user
|
|
err := c.QueryOne(ctx, &u, `SELECT * FROM users LIMIT 1`)
|
|
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
|
|
})
|
|
})
|
|
}
|
|
|
|
// InsertTest runs all tests for making sure the Insert function is
|
|
// working for a given adapter and dialect.
|
|
func InsertTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("Insert", func(t *testing.T) {
|
|
t.Run("success cases", func(t *testing.T) {
|
|
t.Run("single primary key tables", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should insert one user correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u := user{
|
|
Name: "Fernanda",
|
|
Address: address{
|
|
Country: "Brazil",
|
|
},
|
|
}
|
|
|
|
err := c.Insert(ctx, usersTable, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, 0)
|
|
|
|
result := user{}
|
|
err = getUserByID(c.db, c.dialect, &result, u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, result.Name, u.Name)
|
|
tt.AssertEqual(t, result.Address, u.Address)
|
|
})
|
|
|
|
t.Run("should insert ignoring the ID with multiple ids", func(t *testing.T) {
|
|
if dialect.InsertMethod() != sqldialect.InsertWithLastInsertID {
|
|
return
|
|
}
|
|
|
|
// Using columns "id" and "name" as IDs:
|
|
table := NewTable("users", "id", "name")
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u := user{
|
|
Name: "No ID returned",
|
|
Age: 3434, // Random number to avoid false positives on this test
|
|
|
|
Address: address{
|
|
Country: "Brazil 3434",
|
|
},
|
|
}
|
|
|
|
err = c.Insert(ctx, table, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, u.ID, uint(0))
|
|
|
|
result := user{}
|
|
err = getUserByName(c.db, dialect, &result, "No ID returned")
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, result.Age, u.Age)
|
|
tt.AssertEqual(t, result.Address, u.Address)
|
|
})
|
|
|
|
t.Run("should work with anonymous structs", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
err = c.Insert(ctx, usersTable, &struct {
|
|
ID int `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
Address map[string]interface{} `ksql:"address,json"`
|
|
}{Name: "fake-name", Address: map[string]interface{}{"city": "bar"}})
|
|
tt.AssertNoErr(t, err)
|
|
})
|
|
|
|
t.Run("should work with preset IDs", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
usersByName := NewTable("users", "name")
|
|
|
|
err = c.Insert(ctx, usersByName, &struct {
|
|
Name string `ksql:"name"`
|
|
Age int `ksql:"age"`
|
|
}{Name: "Preset Name", Age: 5455})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var inserted user
|
|
err := getUserByName(db, dialect, &inserted, "Preset Name")
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, inserted.Age, 5455)
|
|
})
|
|
|
|
t.Run("should work and retrieve the ID for structs with no attributes", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type tsUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,skipInserts"`
|
|
}
|
|
u := tsUser{
|
|
Name: "Letícia",
|
|
}
|
|
err := c.Insert(ctx, usersTable, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, 0)
|
|
|
|
var untaggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name *string `ksql:"name"`
|
|
}
|
|
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
|
|
})
|
|
})
|
|
|
|
t.Run("composite key tables", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should insert in composite key tables correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
table := NewTable("user_permissions", "id", "user_id", "perm_id")
|
|
err = c.Insert(ctx, table, &userPermission{
|
|
UserID: 1,
|
|
PermID: 42,
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
userPerms, err := getUserPermissionsByUser(db, dialect, 1)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(userPerms), 1)
|
|
tt.AssertEqual(t, userPerms[0].UserID, 1)
|
|
tt.AssertEqual(t, userPerms[0].PermID, 42)
|
|
})
|
|
|
|
t.Run("should accept partially provided values for composite key tables", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
// Table defined with 3 values, but we'll provide only 2,
|
|
// the third will be generated for the purposes of this test:
|
|
table := NewTable("user_permissions", "id", "user_id", "perm_id")
|
|
permission := userPermission{
|
|
UserID: 2,
|
|
PermID: 42,
|
|
}
|
|
err = c.Insert(ctx, table, &permission)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
userPerms, err := getUserPermissionsByUser(db, dialect, 2)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
// Should retrieve the generated ID from the database,
|
|
// only if the database supports returning multiple values:
|
|
switch c.dialect.InsertMethod() {
|
|
case sqldialect.InsertWithNoIDRetrieval, sqldialect.InsertWithLastInsertID:
|
|
tt.AssertEqual(t, permission.ID, 0)
|
|
tt.AssertEqual(t, len(userPerms), 1)
|
|
tt.AssertEqual(t, userPerms[0].UserID, 2)
|
|
tt.AssertEqual(t, userPerms[0].PermID, 42)
|
|
case sqldialect.InsertWithReturning, sqldialect.InsertWithOutput:
|
|
tt.AssertNotEqual(t, permission.ID, 0)
|
|
tt.AssertEqual(t, len(userPerms), 1)
|
|
tt.AssertEqual(t, userPerms[0].ID, permission.ID)
|
|
tt.AssertEqual(t, userPerms[0].UserID, 2)
|
|
tt.AssertEqual(t, userPerms[0].PermID, 42)
|
|
}
|
|
})
|
|
|
|
t.Run("when inserting a struct with no values but composite keys should still retrieve the IDs", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
// Table defined with 3 values, but we'll provide only 2,
|
|
// the third will be generated for the purposes of this test:
|
|
table := NewTable("user_permissions", "id", "user_id", "perm_id")
|
|
type taggedPerm struct {
|
|
ID uint `ksql:"id"`
|
|
UserID int `ksql:"user_id"`
|
|
PermID int `ksql:"perm_id"`
|
|
Type string `ksql:"type,skipInserts"`
|
|
}
|
|
permission := taggedPerm{
|
|
UserID: 3,
|
|
PermID: 43,
|
|
}
|
|
err := c.Insert(ctx, table, &permission)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, permission.ID, 0)
|
|
|
|
var untaggedPerm struct {
|
|
ID uint `ksql:"id"`
|
|
UserID int `ksql:"user_id"`
|
|
PermID int `ksql:"perm_id"`
|
|
Type *string `ksql:"type"`
|
|
}
|
|
err = c.QueryOne(ctx, &untaggedPerm, `FROM user_permissions WHERE user_id = 3 AND perm_id = 43`)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedPerm.Type, (*string)(nil))
|
|
|
|
// Should retrieve the generated ID from the database,
|
|
// only if the database supports returning multiple values:
|
|
switch c.dialect.InsertMethod() {
|
|
case sqldialect.InsertWithNoIDRetrieval, sqldialect.InsertWithLastInsertID:
|
|
tt.AssertEqual(t, permission.ID, uint(0))
|
|
tt.AssertEqual(t, untaggedPerm.UserID, 3)
|
|
tt.AssertEqual(t, untaggedPerm.PermID, 43)
|
|
case sqldialect.InsertWithReturning, sqldialect.InsertWithOutput:
|
|
tt.AssertEqual(t, untaggedPerm.ID, permission.ID)
|
|
tt.AssertEqual(t, untaggedPerm.UserID, 3)
|
|
tt.AssertEqual(t, untaggedPerm.PermID, 43)
|
|
}
|
|
})
|
|
})
|
|
})
|
|
|
|
t.Run("testing error cases", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should report error for invalid input types", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Insert(ctx, usersTable, "foo")
|
|
tt.AssertNotEqual(t, err, nil)
|
|
|
|
err = c.Insert(ctx, usersTable, nullable.String("foo"))
|
|
tt.AssertNotEqual(t, err, nil)
|
|
|
|
err = c.Insert(ctx, usersTable, map[string]interface{}{
|
|
"name": "foo",
|
|
"age": 12,
|
|
})
|
|
tt.AssertNotEqual(t, err, nil)
|
|
|
|
cantInsertSlice := []interface{}{
|
|
&user{Name: "foo", Age: 22},
|
|
&user{Name: "bar", Age: 32},
|
|
}
|
|
err = c.Insert(ctx, usersTable, cantInsertSlice)
|
|
tt.AssertNotEqual(t, err, nil)
|
|
|
|
// We might want to support this in the future, but not for now:
|
|
err = c.Insert(ctx, usersTable, user{Name: "not a ptr to user", Age: 42})
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report error if for some reason the InsertMethod is invalid", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
// This is an invalid value:
|
|
c.dialect = brokenDialect{}
|
|
|
|
err = c.Insert(ctx, usersTable, &user{Name: "foo"})
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var u *user
|
|
err := c.Insert(ctx, usersTable, u)
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report error if table contains an empty ID name", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Insert(ctx, NewTable("users", ""), &user{Name: "fake-name"})
|
|
tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string")
|
|
})
|
|
|
|
t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Insert(ctx, NewTable("", "id"), &user{Name: "fake-name"})
|
|
tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string")
|
|
})
|
|
|
|
t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Insert(ctx, usersTable, &struct {
|
|
ID string `ksql:"id"`
|
|
NonExistingColumn int `ksql:"non_existing"`
|
|
Name string `ksql:"name"`
|
|
}{NonExistingColumn: 42, Name: "fake-name"})
|
|
|
|
tt.AssertErrContains(t, err, "column", "non_existing")
|
|
})
|
|
|
|
t.Run("should not panic if the ID column doesn't exist in the database", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
brokenTable := NewTable("users", "non_existing_id")
|
|
|
|
_ = c.Insert(ctx, brokenTable, &struct {
|
|
ID string `ksql:"non_existing_id"`
|
|
Age int `ksql:"age"`
|
|
Name string `ksql:"name"`
|
|
}{Age: 42, Name: "fake-name"})
|
|
})
|
|
|
|
t.Run("should not panic if the ID column is missing in the struct", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Insert(ctx, usersTable, &struct {
|
|
Age int `ksql:"age"`
|
|
Name string `ksql:"name"`
|
|
}{Age: 42, Name: "Inserted With no ID"})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var u user
|
|
err = getUserByName(db, dialect, &u, "Inserted With no ID")
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, uint(0))
|
|
tt.AssertEqual(t, u.Age, 42)
|
|
})
|
|
|
|
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Insert(ctx, usersTable, &user{Age: 42, Name: "FakeUserName"})
|
|
|
|
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
// brokenDialect is a test-only dialect stub whose InsertMethod returns an
// unrecognized value. The Insert error-case tests assign it to a test DB's
// dialect to check that an invalid insert method is reported as an error.
type brokenDialect struct{}
|
|
|
|
func (brokenDialect) InsertMethod() sqldialect.InsertMethod {
|
|
return sqldialect.InsertMethod(42)
|
|
}
|
|
|
|
func (brokenDialect) Escape(str string) string {
|
|
return str
|
|
}
|
|
|
|
func (brokenDialect) Placeholder(idx int) string {
|
|
return "?"
|
|
}
|
|
|
|
func (brokenDialect) DriverName() string {
|
|
return "fake-driver-name"
|
|
}
|
|
|
|
// DeleteTest runs all tests for making sure the Delete function is
|
|
// working for a given adapter and dialect.
|
|
func DeleteTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("Delete", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should delete from tables with a single primary key correctly", func(t *testing.T) {
|
|
tests := []struct {
|
|
desc string
|
|
deletionKeyForUser func(u user) interface{}
|
|
}{
|
|
{
|
|
desc: "passing only the ID as key",
|
|
deletionKeyForUser: func(u user) interface{} {
|
|
return u.ID
|
|
},
|
|
},
|
|
{
|
|
desc: "passing only the entire user",
|
|
deletionKeyForUser: func(u user) interface{} {
|
|
return u
|
|
},
|
|
},
|
|
{
|
|
desc: "passing the address of the user",
|
|
deletionKeyForUser: func(u user) interface{} {
|
|
return &u
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.desc, func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u1 := user{
|
|
Name: "Fernanda",
|
|
}
|
|
|
|
err := c.Insert(ctx, usersTable, &u1)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u1.ID, uint(0))
|
|
|
|
result := user{}
|
|
err = getUserByID(c.db, c.dialect, &result, u1.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, result.ID, u1.ID)
|
|
|
|
u2 := user{
|
|
Name: "Won't be deleted",
|
|
}
|
|
|
|
err = c.Insert(ctx, usersTable, &u2)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u2.ID, uint(0))
|
|
|
|
result = user{}
|
|
err = getUserByID(c.db, c.dialect, &result, u2.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, result.ID, u2.ID)
|
|
|
|
err = c.Delete(ctx, usersTable, test.deletionKeyForUser(u1))
|
|
tt.AssertNoErr(t, err)
|
|
|
|
result = user{}
|
|
err = getUserByID(c.db, c.dialect, &result, u1.ID)
|
|
tt.AssertEqual(t, err, sql.ErrNoRows)
|
|
|
|
result = user{}
|
|
err = getUserByID(c.db, c.dialect, &result, u2.ID)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertNotEqual(t, result.ID, uint(0))
|
|
tt.AssertEqual(t, result.Name, "Won't be deleted")
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("should delete from tables with composite primary keys correctly", func(t *testing.T) {
|
|
t.Run("using structs", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
// This permission should not be deleted, we'll use the id to check it:
|
|
p0 := userPermission{
|
|
UserID: 1,
|
|
PermID: 44,
|
|
}
|
|
err = createUserPermission(db, c.dialect, p0)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
p0, err = getUserPermissionBySecondaryKeys(db, c.dialect, 1, 44)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, p0.ID, 0)
|
|
|
|
p1 := userPermission{
|
|
UserID: 1,
|
|
PermID: 42,
|
|
}
|
|
err = createUserPermission(db, c.dialect, p1)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
err = c.Delete(ctx, userPermissionsTable, p1)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
userPerms, err := getUserPermissionsByUser(db, dialect, 1)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(userPerms), 1)
|
|
tt.AssertEqual(t, userPerms[0].UserID, 1)
|
|
tt.AssertEqual(t, userPerms[0].PermID, 44)
|
|
})
|
|
|
|
t.Run("using maps", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
// This permission should not be deleted, we'll use the id to check it:
|
|
p0 := userPermission{
|
|
UserID: 2,
|
|
PermID: 44,
|
|
}
|
|
err = createUserPermission(db, c.dialect, p0)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
p0, err = getUserPermissionBySecondaryKeys(db, c.dialect, 1, 44)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, p0.ID, 0)
|
|
|
|
p1 := userPermission{
|
|
UserID: 2,
|
|
PermID: 42,
|
|
}
|
|
err = createUserPermission(db, c.dialect, p1)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
err = c.Delete(ctx, userPermissionsTable, map[string]interface{}{
|
|
"user_id": 2,
|
|
"perm_id": 42,
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
userPerms, err := getUserPermissionsByUser(db, dialect, 2)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(userPerms), 1)
|
|
tt.AssertEqual(t, userPerms[0].UserID, 2)
|
|
tt.AssertEqual(t, userPerms[0].PermID, 44)
|
|
})
|
|
})
|
|
|
|
t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Delete(ctx, usersTable, 4200)
|
|
tt.AssertEqual(t, err, ErrRecordNotFound)
|
|
})
|
|
|
|
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var u *user
|
|
err := c.Delete(ctx, usersTable, u)
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report error if one of the ids is missing from the input", func(t *testing.T) {
|
|
t.Run("single id", func(t *testing.T) {
|
|
t.Run("struct with missing attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, NewTable("users", "id"), &struct {
|
|
// Missing ID
|
|
Name string `ksql:"name"`
|
|
}{Name: "fake-name"})
|
|
tt.AssertErrContains(t, err, "missing required", "id")
|
|
})
|
|
|
|
t.Run("struct with NULL attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, NewTable("users", "id"), &struct {
|
|
// Null ID
|
|
ID *int `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
}{Name: "fake-name"})
|
|
tt.AssertErrContains(t, err, "missing required", "id")
|
|
})
|
|
|
|
t.Run("struct with zero attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, NewTable("users", "id"), &struct {
|
|
// Uninitialized ID
|
|
ID int `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
}{Name: "fake-name"})
|
|
tt.AssertErrContains(t, err, "invalid value", "0", "id")
|
|
})
|
|
})
|
|
|
|
t.Run("multiple ids", func(t *testing.T) {
|
|
t.Run("struct with missing attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, userPermissionsTable, &struct {
|
|
// Missing PermID
|
|
UserID int `ksql:"user_id"`
|
|
Name string `ksql:"name"`
|
|
}{
|
|
UserID: 1,
|
|
Name: "fake-name",
|
|
})
|
|
tt.AssertErrContains(t, err, "missing required", "perm_id")
|
|
})
|
|
|
|
t.Run("map with missing attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, userPermissionsTable, map[string]interface{}{
|
|
// Missing PermID
|
|
"user_id": 1,
|
|
"name": "fake-name",
|
|
})
|
|
tt.AssertErrContains(t, err, "missing required", "perm_id")
|
|
})
|
|
|
|
t.Run("struct with NULL attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, userPermissionsTable, &struct {
|
|
UserID int `ksql:"user_id"`
|
|
PermID *int `ksql:"perm_id"`
|
|
Name string `ksql:"name"`
|
|
}{
|
|
// Null Perm ID
|
|
UserID: 1,
|
|
PermID: nil,
|
|
Name: "fake-name",
|
|
})
|
|
tt.AssertErrContains(t, err, "missing required", "perm_id")
|
|
})
|
|
|
|
t.Run("map with NULL attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, userPermissionsTable, map[string]interface{}{
|
|
// Null Perm ID
|
|
"user_id": 1,
|
|
"perm_id": nil,
|
|
"name": "fake-name",
|
|
})
|
|
tt.AssertErrContains(t, err, "invalid value", "nil", "perm_id")
|
|
})
|
|
|
|
t.Run("struct with zero attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, userPermissionsTable, &struct {
|
|
UserID int `ksql:"user_id"`
|
|
PermID int `ksql:"perm_id"`
|
|
Name string `ksql:"name"`
|
|
}{
|
|
// Zero Perm ID
|
|
UserID: 1,
|
|
PermID: 0,
|
|
Name: "fake-name",
|
|
})
|
|
tt.AssertErrContains(t, err, "invalid value", "0", "perm_id")
|
|
})
|
|
|
|
t.Run("map with zero attr", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, userPermissionsTable, map[string]interface{}{
|
|
// Zero Perm ID
|
|
"user_id": 1,
|
|
"perm_id": 0,
|
|
"name": "fake-name",
|
|
})
|
|
tt.AssertErrContains(t, err, "invalid value", "0", "perm_id")
|
|
})
|
|
})
|
|
})
|
|
|
|
t.Run("should report error if table contains an empty ID name", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, NewTable("users", ""), &user{ID: 42, Name: "fake-name"})
|
|
tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string")
|
|
})
|
|
|
|
t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, NewTable("", "id"), &user{Name: "fake-name"})
|
|
tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string")
|
|
})
|
|
|
|
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Delete(ctx, usersTable, &user{ID: 42, Name: "fake-name"})
|
|
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
|
|
})
|
|
})
|
|
}
|
|
|
|
// PatchTest runs all tests for making sure the Patch function is
|
|
// working for a given adapter and dialect.
|
|
func PatchTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("Patch", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("should update one user{} correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u := user{
|
|
Name: "Letícia",
|
|
}
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
err = getUserByName(db, dialect, &u, "Letícia")
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, uint(0))
|
|
|
|
err = c.Patch(ctx, usersTable, user{
|
|
ID: u.ID,
|
|
Name: "Thayane",
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var result user
|
|
err = getUserByID(c.db, c.dialect, &result, u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, result.Name, "Thayane")
|
|
})
|
|
|
|
t.Run("should update one &user{} correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u := user{
|
|
Name: "Letícia",
|
|
}
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
err = getUserByName(db, dialect, &u, "Letícia")
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, uint(0))
|
|
|
|
err = c.Patch(ctx, usersTable, &user{
|
|
ID: u.ID,
|
|
Name: "Thayane",
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var result user
|
|
err = getUserByID(c.db, c.dialect, &result, u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, result.Name, "Thayane")
|
|
})
|
|
|
|
t.Run("should update tables with composite keys correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = createUserPermission(db, c.dialect, userPermission{
|
|
UserID: 42,
|
|
PermID: 43,
|
|
Type: "existingFakeType",
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
existingPerm, err := getUserPermissionBySecondaryKeys(db, c.dialect, 42, 43)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, existingPerm.ID, 0)
|
|
tt.AssertEqual(t, existingPerm.Type, "existingFakeType")
|
|
|
|
err = c.Patch(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &userPermission{
|
|
ID: existingPerm.ID,
|
|
UserID: 42,
|
|
PermID: 43,
|
|
Type: "newFakeType",
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
newPerm, err := getUserPermissionBySecondaryKeys(db, c.dialect, 42, 43)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, newPerm, userPermission{
|
|
ID: existingPerm.ID,
|
|
UserID: 42,
|
|
PermID: 43,
|
|
Type: "newFakeType",
|
|
})
|
|
})
|
|
|
|
t.Run("should ignore null pointers on partial updates", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type partialUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
Age *int `ksql:"age"`
|
|
}
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var u user
|
|
err = getUserByName(db, dialect, &u, "Letícia")
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, uint(0))
|
|
|
|
err = c.Patch(ctx, usersTable, partialUser{
|
|
ID: u.ID,
|
|
// Should be updated because it is not null, just empty:
|
|
Name: "",
|
|
// Should not be updated because it is null:
|
|
Age: nil,
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var result user
|
|
err = getUserByID(c.db, c.dialect, &result, u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, result.Name, "")
|
|
tt.AssertEqual(t, result.Age, 22)
|
|
})
|
|
|
|
t.Run("should update valid pointers on partial updates", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type partialUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
Age *int `ksql:"age"`
|
|
}
|
|
|
|
_, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var u user
|
|
err = getUserByName(db, dialect, &u, "Letícia")
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, uint(0))
|
|
|
|
// Should update all fields:
|
|
err = c.Patch(ctx, usersTable, partialUser{
|
|
ID: u.ID,
|
|
Name: "Thay",
|
|
Age: nullable.Int(42),
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var result user
|
|
err = getUserByID(c.db, c.dialect, &result, u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, result.Name, "Thay")
|
|
tt.AssertEqual(t, result.Age, 42)
|
|
})
|
|
|
|
t.Run("should return ErrRecordNotFound when asked to update an inexistent user", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Patch(ctx, usersTable, user{
|
|
ID: 4200,
|
|
Name: "Thayane",
|
|
})
|
|
tt.AssertEqual(t, err, ErrRecordNotFound)
|
|
})
|
|
|
|
t.Run("should report database errors correctly", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Patch(ctx, NewTable("non_existing_table"), user{
|
|
ID: 1,
|
|
Name: "Thayane",
|
|
})
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
var u *user
|
|
err := c.Patch(ctx, usersTable, u)
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report error if the struct has no fields to update", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Patch(ctx, usersTable, struct {
|
|
ID uint `ksql:"id"` // ID fields are not updated
|
|
Name string `ksql:"name,skipUpdates"` // the skipUpdate modifier should rule this one out
|
|
Age *int `ksql:"age"` // Age is a nil pointer so it would not be updated
|
|
}{
|
|
ID: 1,
|
|
Name: "some name",
|
|
})
|
|
tt.AssertErrContains(t, err, "struct", "no values to update")
|
|
})
|
|
|
|
t.Run("should report error if the id is missing", func(t *testing.T) {
|
|
t.Run("with a single primary key", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Patch(ctx, usersTable, &user{
|
|
// Missing ID
|
|
Name: "Jane",
|
|
})
|
|
tt.AssertErrContains(t, err, "invalid value", "0", "'id'")
|
|
})
|
|
|
|
t.Run("with composite keys", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Patch(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &userPermission{
|
|
ID: 1,
|
|
// Missing UserID
|
|
PermID: 42,
|
|
})
|
|
tt.AssertErrContains(t, err, "invalid value", "0", "'user_id'")
|
|
})
|
|
})
|
|
|
|
t.Run("should report error if the struct has no id", func(t *testing.T) {
|
|
t.Run("with a single primary key", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Patch(ctx, usersTable, &struct {
|
|
// Missing ID
|
|
Name string `ksql:"name"`
|
|
}{
|
|
Name: "Jane",
|
|
})
|
|
tt.AssertErrContains(t, err, "missing", "ID fields", "id")
|
|
})
|
|
|
|
t.Run("with composite keys", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Patch(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &struct {
|
|
ID int `ksql:"id"`
|
|
// Missing UserID
|
|
PermID int `ksql:"perm_id"`
|
|
}{
|
|
ID: 1,
|
|
PermID: 42,
|
|
})
|
|
tt.AssertErrContains(t, err, "missing", "ID fields", "user_id")
|
|
})
|
|
})
|
|
|
|
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err = c.Patch(ctx, usersTable, user{
|
|
ID: 1,
|
|
Name: "Thayane",
|
|
})
|
|
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
|
|
})
|
|
})
|
|
}
|
|
|
|
// QueryChunksTest runs all tests for making sure the QueryChunks function is
|
|
// working for a given adapter and dialect.
|
|
func QueryChunksTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("QueryChunks", func(t *testing.T) {
|
|
variations := []struct {
|
|
desc string
|
|
queryPrefix string
|
|
}{
|
|
{
|
|
desc: "with select *",
|
|
queryPrefix: "SELECT * ",
|
|
},
|
|
{
|
|
desc: "building the SELECT part of the query internally",
|
|
queryPrefix: "",
|
|
},
|
|
}
|
|
for _, variation := range variations {
|
|
t.Run(variation.desc, func(t *testing.T) {
|
|
t.Run("should query a single row correctly", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{
|
|
Name: "User1",
|
|
Address: address{Country: "BR"},
|
|
})
|
|
|
|
var length int
|
|
var u user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `FROM users WHERE name = ` + c.dialect.Placeholder(0),
|
|
Params: []interface{}{"User1"},
|
|
|
|
ChunkSize: 100,
|
|
ForEachChunk: func(users []user) error {
|
|
length = len(users)
|
|
if length > 0 {
|
|
u = users[0]
|
|
}
|
|
return nil
|
|
},
|
|
})
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, length, 1)
|
|
tt.AssertNotEqual(t, u.ID, uint(0))
|
|
tt.AssertEqual(t, u.Name, "User1")
|
|
tt.AssertEqual(t, u.Address.Country, "BR")
|
|
})
|
|
|
|
t.Run("should query one chunk correctly", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Address: address{Country: "US"}})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2", Address: address{Country: "BR"}})
|
|
|
|
var lengths []int
|
|
var users []user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
|
|
Params: []interface{}{"User%"},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []user) error {
|
|
users = append(users, buffer...)
|
|
lengths = append(lengths, len(buffer))
|
|
return nil
|
|
},
|
|
})
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(lengths), 1)
|
|
tt.AssertEqual(t, lengths[0], 2)
|
|
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertEqual(t, users[0].Address.Country, "US")
|
|
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
tt.AssertEqual(t, users[1].Address.Country, "BR")
|
|
})
|
|
|
|
t.Run("should query chunks of 1 correctly", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Address: address{Country: "US"}})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2", Address: address{Country: "BR"}})
|
|
|
|
var lengths []int
|
|
var users []user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
|
|
Params: []interface{}{"User%"},
|
|
|
|
ChunkSize: 1,
|
|
ForEachChunk: func(buffer []user) error {
|
|
lengths = append(lengths, len(buffer))
|
|
users = append(users, buffer...)
|
|
return nil
|
|
},
|
|
})
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 2)
|
|
tt.AssertEqual(t, lengths, []int{1, 1})
|
|
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertEqual(t, users[0].Address.Country, "US")
|
|
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
tt.AssertEqual(t, users[1].Address.Country, "BR")
|
|
})
|
|
|
|
t.Run("should load partially filled chunks correctly", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
|
|
|
|
var lengths []int
|
|
var users []user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
|
|
Params: []interface{}{"User%"},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []user) error {
|
|
lengths = append(lengths, len(buffer))
|
|
users = append(users, buffer...)
|
|
return nil
|
|
},
|
|
})
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 3)
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
tt.AssertNotEqual(t, users[2].ID, uint(0))
|
|
tt.AssertEqual(t, users[2].Name, "User3")
|
|
tt.AssertEqual(t, lengths, []int{2, 1})
|
|
})
|
|
|
|
// xxx
|
|
t.Run("should query joined tables correctly", func(t *testing.T) {
|
|
// This test only makes sense with no query prefix
|
|
if variation.queryPrefix != "" {
|
|
return
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
joao := user{
|
|
Name: "Thiago Ribeiro",
|
|
Age: 24,
|
|
}
|
|
thatiana := user{
|
|
Name: "Thatiana Ribeiro",
|
|
Age: 20,
|
|
}
|
|
|
|
c := newTestDB(db, dialect)
|
|
_ = c.Insert(ctx, usersTable, &joao)
|
|
_ = c.Insert(ctx, usersTable, &thatiana)
|
|
|
|
_, err := db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`))
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post2')`))
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'Thiago Post1')`))
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var lengths []int
|
|
var users []user
|
|
var posts []post
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: fmt.Sprint(
|
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
|
` ORDER BY u.id, p.id`,
|
|
),
|
|
Params: []interface{}{"% Ribeiro"},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(chunk []struct {
|
|
User user `tablename:"u"`
|
|
Post post `tablename:"p"`
|
|
}) error {
|
|
lengths = append(lengths, len(chunk))
|
|
for _, row := range chunk {
|
|
users = append(users, row.User)
|
|
posts = append(posts, row.Post)
|
|
}
|
|
return nil
|
|
},
|
|
})
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(posts), 3)
|
|
|
|
tt.AssertEqual(t, users[0].ID, joao.ID)
|
|
tt.AssertEqual(t, users[0].Name, "Thiago Ribeiro")
|
|
tt.AssertEqual(t, posts[0].Title, "Thiago Post1")
|
|
|
|
tt.AssertEqual(t, users[1].ID, thatiana.ID)
|
|
tt.AssertEqual(t, users[1].Name, "Thatiana Ribeiro")
|
|
tt.AssertEqual(t, posts[1].Title, "Thatiana Post1")
|
|
|
|
tt.AssertEqual(t, users[2].ID, thatiana.ID)
|
|
tt.AssertEqual(t, users[2].Name, "Thatiana Ribeiro")
|
|
tt.AssertEqual(t, posts[2].Title, "Thatiana Post2")
|
|
})
|
|
|
|
t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
|
|
|
|
var lengths []int
|
|
var users []user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
|
|
Params: []interface{}{"User%"},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []user) error {
|
|
lengths = append(lengths, len(buffer))
|
|
users = append(users, buffer...)
|
|
return ErrAbortIteration
|
|
},
|
|
})
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 2)
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
tt.AssertEqual(t, lengths, []int{2})
|
|
})
|
|
|
|
t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
|
|
|
|
returnVals := []error{nil, ErrAbortIteration}
|
|
var lengths []int
|
|
var users []user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
|
|
Params: []interface{}{"User%"},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []user) error {
|
|
lengths = append(lengths, len(buffer))
|
|
users = append(users, buffer...)
|
|
|
|
return shiftErrSlice(&returnVals)
|
|
},
|
|
})
|
|
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, len(users), 3)
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
tt.AssertNotEqual(t, users[2].ID, uint(0))
|
|
tt.AssertEqual(t, users[2].Name, "User3")
|
|
tt.AssertEqual(t, lengths, []int{2, 1})
|
|
})
|
|
|
|
t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
|
|
|
|
var lengths []int
|
|
var users []user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
|
|
Params: []interface{}{"User%"},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []user) error {
|
|
lengths = append(lengths, len(buffer))
|
|
users = append(users, buffer...)
|
|
return fmt.Errorf("fake error msg")
|
|
},
|
|
})
|
|
|
|
tt.AssertNotEqual(t, err, nil)
|
|
tt.AssertEqual(t, len(users), 2)
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
tt.AssertEqual(t, lengths, []int{2})
|
|
})
|
|
|
|
t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User3"})
|
|
|
|
returnVals := []error{nil, fmt.Errorf("fake error msg")}
|
|
var lengths []int
|
|
var users []user
|
|
err = c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`,
|
|
Params: []interface{}{"User%"},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []user) error {
|
|
lengths = append(lengths, len(buffer))
|
|
users = append(users, buffer...)
|
|
|
|
return shiftErrSlice(&returnVals)
|
|
},
|
|
})
|
|
|
|
tt.AssertNotEqual(t, err, nil)
|
|
tt.AssertEqual(t, len(users), 3)
|
|
tt.AssertNotEqual(t, users[0].ID, uint(0))
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertNotEqual(t, users[1].ID, uint(0))
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
tt.AssertNotEqual(t, users[2].ID, uint(0))
|
|
tt.AssertEqual(t, users[2].Name, "User3")
|
|
tt.AssertEqual(t, lengths, []int{2, 1})
|
|
})
|
|
|
|
t.Run("should report error if the input function is invalid", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
funcs := []interface{}{
|
|
nil,
|
|
"not a function",
|
|
func() error {
|
|
return nil
|
|
},
|
|
func(extraInputValue []user, extra []user) error {
|
|
return nil
|
|
},
|
|
func(invalidArgType string) error {
|
|
return nil
|
|
},
|
|
func(missingReturnType []user) {
|
|
},
|
|
func(users []user) string {
|
|
return ""
|
|
},
|
|
func(extraReturnValue []user) ([]user, error) {
|
|
return nil, nil
|
|
},
|
|
func(notSliceOfStructs []string) error {
|
|
return nil
|
|
},
|
|
}
|
|
|
|
for _, fn := range funcs {
|
|
err := c.QueryChunks(ctx, ChunkParser{
|
|
Query: variation.queryPrefix + `FROM users`,
|
|
Params: []interface{}{},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: fn,
|
|
})
|
|
tt.AssertNotEqual(t, err, nil)
|
|
}
|
|
})
|
|
|
|
t.Run("should report error if the query is not valid", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
err := c.QueryChunks(ctx, ChunkParser{
|
|
Query: `SELECT * FROM not a valid query`,
|
|
Params: []interface{}{},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []user) error {
|
|
return nil
|
|
},
|
|
})
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.QueryChunks(ctx, ChunkParser{
|
|
Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`,
|
|
Params: []interface{}{},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(buffer []struct {
|
|
User user `tablename:"users"`
|
|
Post post `tablename:"posts"`
|
|
}) error {
|
|
return nil
|
|
},
|
|
})
|
|
|
|
tt.AssertErrContains(t, err, "nested struct", "feature")
|
|
})
|
|
})
|
|
}
|
|
|
|
t.Run("error cases", func(t *testing.T) {
|
|
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.QueryChunks(ctx, ChunkParser{
|
|
Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`,
|
|
Params: []interface{}{},
|
|
|
|
ChunkSize: 2,
|
|
ForEachChunk: func(rows []user) error {
|
|
return nil
|
|
},
|
|
})
|
|
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
// TransactionTest runs all tests for making sure the Transaction function is
|
|
// working for a given adapter and dialect.
|
|
func TransactionTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("Transaction", func(t *testing.T) {
|
|
t.Run("should query a single row correctly", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1"})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2"})
|
|
|
|
var users []user
|
|
err = c.Transaction(ctx, func(db Provider) error {
|
|
db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
|
return nil
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, len(users), 2)
|
|
tt.AssertEqual(t, users[0].Name, "User1")
|
|
tt.AssertEqual(t, users[1].Name, "User2")
|
|
})
|
|
|
|
t.Run("should work normally in nested transactions", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u := user{
|
|
Name: "User1",
|
|
}
|
|
err = c.Insert(ctx, usersTable, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, 0)
|
|
|
|
var updatedUser user
|
|
err = c.Transaction(ctx, func(db Provider) error {
|
|
u.Age = 42
|
|
err = db.Patch(ctx, usersTable, &u)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return db.Transaction(ctx, func(db Provider) error {
|
|
return db.QueryOne(ctx, &updatedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
|
})
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, updatedUser.ID, u.ID)
|
|
tt.AssertEqual(t, updatedUser.Age, 42)
|
|
})
|
|
|
|
t.Run("should rollback when there are errors", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u1 := user{Name: "User1", Age: 42}
|
|
u2 := user{Name: "User2", Age: 42}
|
|
_ = c.Insert(ctx, usersTable, &u1)
|
|
_ = c.Insert(ctx, usersTable, &u2)
|
|
|
|
err = c.Transaction(ctx, func(db Provider) error {
|
|
err = db.Insert(ctx, usersTable, &user{Name: "User3"})
|
|
tt.AssertNoErr(t, err)
|
|
err = db.Insert(ctx, usersTable, &user{Name: "User4"})
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.Exec(ctx, "UPDATE users SET age = 22")
|
|
tt.AssertNoErr(t, err)
|
|
|
|
return fmt.Errorf("fake-error")
|
|
})
|
|
tt.AssertNotEqual(t, err, nil)
|
|
tt.AssertEqual(t, err.Error(), "fake-error")
|
|
|
|
var users []user
|
|
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, users, []user{u1, u2})
|
|
})
|
|
|
|
t.Run("should rollback when the fn call panics", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
u1 := user{Name: "User1", Age: 42}
|
|
u2 := user{Name: "User2", Age: 42}
|
|
_ = c.Insert(ctx, usersTable, &u1)
|
|
_ = c.Insert(ctx, usersTable, &u2)
|
|
|
|
panicPayload := tt.PanicHandler(func() {
|
|
c.Transaction(ctx, func(db Provider) error {
|
|
err = db.Insert(ctx, usersTable, &user{Name: "User3"})
|
|
tt.AssertNoErr(t, err)
|
|
err = db.Insert(ctx, usersTable, &user{Name: "User4"})
|
|
tt.AssertNoErr(t, err)
|
|
_, err = db.Exec(ctx, "UPDATE users SET age = 22")
|
|
tt.AssertNoErr(t, err)
|
|
|
|
panic("fakePanicPayload")
|
|
})
|
|
})
|
|
tt.AssertEqual(t, panicPayload, "fakePanicPayload")
|
|
|
|
var users []user
|
|
err = c.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC")
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, users, []user{u1, u2})
|
|
})
|
|
|
|
t.Run("should handle rollback errors when the fn call panics", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
cMock := mockTxBeginner{
|
|
DBAdapter: c.db,
|
|
BeginTxFn: func(ctx context.Context) (Tx, error) {
|
|
return mockTx{
|
|
DBAdapter: c.db,
|
|
RollbackFn: func(ctx context.Context) error {
|
|
return fmt.Errorf("fakeRollbackErrMsg")
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
c.db = cMock
|
|
|
|
panicPayload := tt.PanicHandler(func() {
|
|
c.Transaction(ctx, func(db Provider) error {
|
|
panic("fakePanicPayload")
|
|
})
|
|
})
|
|
err, ok := panicPayload.(error)
|
|
tt.AssertEqual(t, ok, true)
|
|
tt.AssertErrContains(t, err, "fakePanicPayload", "fakeRollbackErrMsg")
|
|
})
|
|
|
|
t.Run("should handle rollback errors when fn returns an error", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
cMock := mockTxBeginner{
|
|
DBAdapter: c.db,
|
|
BeginTxFn: func(ctx context.Context) (Tx, error) {
|
|
return mockTx{
|
|
DBAdapter: c.db,
|
|
RollbackFn: func(ctx context.Context) error {
|
|
return fmt.Errorf("fakeRollbackErrMsg")
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
c.db = cMock
|
|
|
|
err = c.Transaction(ctx, func(db Provider) error {
|
|
return fmt.Errorf("fakeTransactionErrMsg")
|
|
})
|
|
tt.AssertErrContains(t, err, "fakeTransactionErrMsg", "fakeRollbackErrMsg")
|
|
})
|
|
|
|
t.Run("should report error when BeginTx() fails", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
cMock := mockTxBeginner{
|
|
DBAdapter: c.db,
|
|
BeginTxFn: func(ctx context.Context) (Tx, error) {
|
|
return nil, fmt.Errorf("fakeErrMsg")
|
|
},
|
|
}
|
|
c.db = cMock
|
|
|
|
err := c.Transaction(ctx, func(db Provider) error {
|
|
return nil
|
|
})
|
|
tt.AssertErrContains(t, err, "KSQL", "fakeErrMsg")
|
|
})
|
|
|
|
t.Run("should report error if DBAdapter can't create transactions", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
c.db = mockDBAdapter{}
|
|
|
|
err = c.Transaction(ctx, func(db Provider) error {
|
|
return nil
|
|
})
|
|
tt.AssertErrContains(t, err, "KSQL", "can't start transaction", "DBAdapter", "TxBeginner")
|
|
})
|
|
|
|
t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
err := c.Transaction(ctx, func(db Provider) error {
|
|
return nil
|
|
})
|
|
tt.AssertEqual(t, errors.Is(err, context.Canceled), true)
|
|
})
|
|
})
|
|
}
|
|
|
|
func ModifiersTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("Modifiers", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
t.Run("timeNowUTC modifier", func(t *testing.T) {
|
|
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type tsUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
|
|
}
|
|
u := tsUser{
|
|
Name: "Letícia",
|
|
}
|
|
err := c.Insert(ctx, usersTable, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, 0)
|
|
|
|
var untaggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
UpdatedAt time.Time `ksql:"updated_at"`
|
|
}
|
|
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
now := time.Now()
|
|
tt.AssertApproxTime(t,
|
|
2*time.Second, untaggedUser.UpdatedAt, now,
|
|
"updatedAt should be set to %v, but got: %v", now, untaggedUser.UpdatedAt,
|
|
)
|
|
})
|
|
|
|
t.Run("should be set to time.Now().UTC() on updates", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
UpdatedAt time.Time `ksql:"updated_at"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Laura Ribeiro",
|
|
// Any time different from now:
|
|
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
type taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
|
|
}
|
|
u := taggedUser{
|
|
ID: untaggedUser.ID,
|
|
Name: "Laurinha Ribeiro",
|
|
}
|
|
err = c.Patch(ctx, usersTable, u)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var untaggedUser2 userWithNoTags
|
|
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser2.ID, 0)
|
|
|
|
now := time.Now()
|
|
tt.AssertApproxTime(t,
|
|
2*time.Second, untaggedUser2.UpdatedAt, now,
|
|
"updatedAt should be set to %v, but got: %v", now, untaggedUser2.UpdatedAt,
|
|
)
|
|
})
|
|
|
|
t.Run("should not alter the value on queries", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
UpdatedAt time.Time `ksql:"updated_at"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Marta Ribeiro",
|
|
// Any time different from now:
|
|
UpdatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
var taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
UpdatedAt time.Time `ksql:"updated_at,timeNowUTC"`
|
|
}
|
|
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
|
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
|
tt.AssertEqual(t, taggedUser.UpdatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
|
|
})
|
|
})
|
|
|
|
t.Run("timeNowUTC/skipUpdates modifier", func(t *testing.T) {
|
|
t.Run("should be set to time.Now().UTC() on insertion", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type tsUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
|
|
}
|
|
u := tsUser{
|
|
Name: "Letícia",
|
|
}
|
|
err := c.Insert(ctx, usersTable, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, 0)
|
|
|
|
var untaggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
CreatedAt time.Time `ksql:"created_at"`
|
|
}
|
|
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
now := time.Now()
|
|
tt.AssertApproxTime(t,
|
|
2*time.Second, untaggedUser.CreatedAt, now,
|
|
"updatedAt should be set to %v, but got: %v", now, untaggedUser.CreatedAt,
|
|
)
|
|
})
|
|
|
|
t.Run("should be ignored on updates", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
CreatedAt time.Time `ksql:"created_at"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Laura Ribeiro",
|
|
// Any time different from now:
|
|
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
type taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
|
|
}
|
|
u := taggedUser{
|
|
ID: untaggedUser.ID,
|
|
Name: "Laurinha Ribeiro",
|
|
// Some random time that should be ignored:
|
|
CreatedAt: tt.ParseTime(t, "1999-08-05T14:00:00Z"),
|
|
}
|
|
err = c.Patch(ctx, usersTable, u)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var untaggedUser2 userWithNoTags
|
|
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedUser2.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
|
|
})
|
|
|
|
t.Run("should not alter the value on queries", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
CreatedAt time.Time `ksql:"created_at"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Marta Ribeiro",
|
|
// Any time different from now:
|
|
CreatedAt: tt.ParseTime(t, "2000-08-05T14:00:00Z"),
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
var taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
CreatedAt time.Time `ksql:"created_at,timeNowUTC/skipUpdates"`
|
|
}
|
|
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
|
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
|
tt.AssertEqual(t, taggedUser.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z"))
|
|
})
|
|
})
|
|
|
|
t.Run("skipInserts modifier", func(t *testing.T) {
|
|
t.Run("should ignore the field during insertions", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type tsUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,skipInserts"`
|
|
Age int `ksql:"age"`
|
|
}
|
|
u := tsUser{
|
|
Name: "Letícia",
|
|
Age: 22,
|
|
}
|
|
err := c.Insert(ctx, usersTable, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, 0)
|
|
|
|
var untaggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name *string `ksql:"name"`
|
|
Age int `ksql:"age"`
|
|
}
|
|
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
|
|
tt.AssertEqual(t, untaggedUser.Age, 22)
|
|
})
|
|
|
|
t.Run("should have no effect on updates", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
Age int `ksql:"age"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Laurinha Ribeiro",
|
|
Age: 11,
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
type taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,skipInserts"`
|
|
Age int `ksql:"age"`
|
|
}
|
|
u := taggedUser{
|
|
ID: untaggedUser.ID,
|
|
Name: "Laura Ribeiro",
|
|
Age: 12,
|
|
}
|
|
err = c.Patch(ctx, usersTable, u)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var untaggedUser2 userWithNoTags
|
|
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro")
|
|
tt.AssertEqual(t, untaggedUser2.Age, 12)
|
|
})
|
|
|
|
t.Run("should not alter the value on queries", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Marta Ribeiro",
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
var taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,skipInserts"`
|
|
}
|
|
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
|
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
|
})
|
|
})
|
|
|
|
t.Run("skipUpdates modifier", func(t *testing.T) {
|
|
t.Run("should set the field on insertion", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type tsUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,skipUpdates"`
|
|
}
|
|
u := tsUser{
|
|
Name: "Letícia",
|
|
}
|
|
err := c.Insert(ctx, usersTable, &u)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u.ID, 0)
|
|
|
|
var untaggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
}
|
|
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedUser.Name, "Letícia")
|
|
})
|
|
|
|
t.Run("should be ignored on updates", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
Age int `ksql:"age"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Laurinha Ribeiro",
|
|
Age: 11,
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
type taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,skipUpdates"`
|
|
Age int `ksql:"age"`
|
|
}
|
|
u := taggedUser{
|
|
ID: untaggedUser.ID,
|
|
Name: "Laura Ribeiro",
|
|
Age: 12,
|
|
}
|
|
err = c.Patch(ctx, usersTable, u)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var untaggedUser2 userWithNoTags
|
|
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedUser2.Name, "Laurinha Ribeiro")
|
|
tt.AssertEqual(t, untaggedUser2.Age, 12)
|
|
})
|
|
|
|
t.Run("should not alter the value on queries", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Marta Ribeiro",
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
var taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,skipUpdates"`
|
|
}
|
|
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
|
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
|
|
})
|
|
})
|
|
|
|
t.Run("nullable modifier", func(t *testing.T) {
|
|
t.Run("should prevent null fields from being ignored during insertions", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
// The default value of the column "nullable_field"
|
|
// is the string: "not_null".
|
|
//
|
|
// So the tagged struct below should insert passing NULL
|
|
// and the untagged should insert not passing any value
|
|
// for this column, thus, only the second one should create
|
|
// a recording using the default value.
|
|
|
|
var taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
NullableField *string `ksql:"nullable_field,nullable"`
|
|
}
|
|
|
|
var untaggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
NullableField *string `ksql:"nullable_field"`
|
|
}
|
|
|
|
err := c.Insert(ctx, usersTable, &taggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, taggedUser.ID, 0)
|
|
|
|
err = c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
err = c.QueryOne(ctx, &taggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), taggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), untaggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, taggedUser.NullableField == nil, true)
|
|
tt.AssertEqual(t, untaggedUser.NullableField, nullable.String("not_null"))
|
|
})
|
|
|
|
t.Run("should prevent null fields from being ignored during updates", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
NullableField *string `ksql:"nullable_field"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: "Laurinha Ribeiro",
|
|
NullableField: nullable.String("fakeValue"),
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
type taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name"`
|
|
NullableField *string `ksql:"nullable_field,nullable"`
|
|
}
|
|
u := taggedUser{
|
|
ID: untaggedUser.ID,
|
|
Name: "Laura Ribeiro",
|
|
NullableField: nil,
|
|
}
|
|
err = c.Patch(ctx, usersTable, u)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var untaggedUser2 userWithNoTags
|
|
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro")
|
|
tt.AssertEqual(t, untaggedUser2.NullableField == nil, true)
|
|
})
|
|
|
|
t.Run("should not alter the value on queries", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type userWithNoTags struct {
|
|
ID uint `ksql:"id"`
|
|
Name *string `ksql:"name"`
|
|
}
|
|
untaggedUser := userWithNoTags{
|
|
Name: nullable.String("Marta Ribeiro"),
|
|
}
|
|
err := c.Insert(ctx, usersTable, &untaggedUser)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, untaggedUser.ID, 0)
|
|
|
|
var taggedUser struct {
|
|
ID uint `ksql:"id"`
|
|
Name *string `ksql:"name,nullable"`
|
|
}
|
|
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
|
|
tt.AssertEqual(t, taggedUser.Name, nullable.String("Marta Ribeiro"))
|
|
})
|
|
|
|
t.Run("should cause no effect if used on a non pointer field", func(t *testing.T) {
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
c := newTestDB(db, dialect)
|
|
|
|
type user struct {
|
|
ID uint `ksql:"id"`
|
|
Name string `ksql:"name,nullable"`
|
|
Age int `ksql:"age,nullable"`
|
|
}
|
|
u1 := user{
|
|
Name: "Marta Ribeiro",
|
|
}
|
|
err := c.Insert(ctx, usersTable, &u1)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertNotEqual(t, u1.ID, 0)
|
|
|
|
err = c.Patch(ctx, usersTable, &struct {
|
|
ID uint `ksql:"id"`
|
|
Age int `ksql:"age,nullable"`
|
|
}{
|
|
ID: u1.ID,
|
|
Age: 42,
|
|
})
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var u2 user
|
|
err = c.QueryOne(ctx, &u2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u1.ID)
|
|
tt.AssertNoErr(t, err)
|
|
tt.AssertEqual(t, u2.ID, u1.ID)
|
|
tt.AssertEqual(t, u2.Name, "Marta Ribeiro")
|
|
tt.AssertEqual(t, u2.Age, 42)
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
// ScanRowsTest runs all tests for making sure the ScanRows feature is
|
|
// working for a given adapter and dialect.
|
|
func ScanRowsTest(
|
|
t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
connStr string,
|
|
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
|
) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("ScanRows", func(t *testing.T) {
|
|
t.Run("should scan users correctly", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
c := newTestDB(db, dialect)
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Age: 22})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User2", Age: 14})
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User3", Age: 43})
|
|
|
|
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
|
tt.AssertNoErr(t, err)
|
|
defer rows.Close()
|
|
|
|
tt.AssertEqual(t, rows.Next(), true)
|
|
|
|
var u user
|
|
err = scanRows(ctx, dialect, rows, &u)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, u.Name, "User2")
|
|
tt.AssertEqual(t, u.Age, 14)
|
|
})
|
|
|
|
t.Run("should ignore extra columns from query", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
c := newTestDB(db, dialect)
|
|
_ = c.Insert(ctx, usersTable, &user{Name: "User1", Age: 22})
|
|
|
|
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'")
|
|
tt.AssertNoErr(t, err)
|
|
defer rows.Close()
|
|
|
|
tt.AssertEqual(t, rows.Next(), true)
|
|
|
|
var u struct {
|
|
ID int `ksql:"id"`
|
|
Age int `ksql:"age"`
|
|
|
|
// Omitted for testing purposes:
|
|
// Name string `ksql:"name"`
|
|
}
|
|
err = scanRows(ctx, dialect, rows, &u)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
tt.AssertEqual(t, u.Age, 22)
|
|
})
|
|
|
|
t.Run("should report scan errors", func(t *testing.T) {
|
|
type brokenUser struct {
|
|
ID uint `ksql:"id"`
|
|
|
|
// The error will happen here, when scanning
|
|
// an integer into a attribute of type struct{}:
|
|
Age struct{} `ksql:"age"`
|
|
}
|
|
|
|
type brokenNestedStruct struct {
|
|
User struct {
|
|
ID uint `ksql:"id"`
|
|
|
|
// The error will happen here, when scanning
|
|
// an integer into a attribute of type struct:
|
|
Age struct{} `ksql:"age"`
|
|
} `tablename:"u"`
|
|
Post post `tablename:"p"`
|
|
}
|
|
|
|
tests := []struct {
|
|
desc string
|
|
query string
|
|
scanTarget interface{}
|
|
expectErrToContain []string
|
|
}{
|
|
{
|
|
desc: "with anonymous structs",
|
|
query: "FROM users WHERE name='User22'",
|
|
scanTarget: &struct {
|
|
ID uint `ksql:"id"`
|
|
|
|
// The error will happen here, when scanning
|
|
// an integer into a attribute of type struct{}:
|
|
Age struct{} `ksql:"age"`
|
|
}{},
|
|
expectErrToContain: []string{" .Age", "struct {}"},
|
|
},
|
|
{
|
|
desc: "with named structs",
|
|
query: "FROM users WHERE name='User22'",
|
|
scanTarget: &brokenUser{},
|
|
expectErrToContain: []string{"brokenUser.Age", "struct {}"},
|
|
},
|
|
{
|
|
desc: "with anonymous nested structs",
|
|
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
|
|
scanTarget: &struct {
|
|
User struct {
|
|
ID uint `ksql:"id"`
|
|
|
|
// The error will happen here, when scanning
|
|
// an integer into a attribute of type struct:
|
|
Age struct{} `ksql:"age"`
|
|
} `tablename:"u"`
|
|
Post post `tablename:"p"`
|
|
}{},
|
|
expectErrToContain: []string{".User.Age", "struct {}"},
|
|
},
|
|
{
|
|
desc: "with named nested structs",
|
|
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
|
|
scanTarget: &brokenNestedStruct{},
|
|
expectErrToContain: []string{"brokenNestedStruct.User.Age", "struct {}"},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.desc, func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
c := newTestDB(db, dialect)
|
|
|
|
u := user{Name: "User22", Age: 22}
|
|
_ = c.Insert(ctx, usersTable, &u)
|
|
_ = c.Insert(ctx, postsTable, &post{UserID: u.ID, Title: "FakeTitle"})
|
|
|
|
query := mustBuildSelectQuery(t, dialect, test.scanTarget, test.query)
|
|
|
|
rows, err := db.QueryContext(ctx, query)
|
|
tt.AssertNoErr(t, err)
|
|
defer rows.Close()
|
|
|
|
tt.AssertEqual(t, rows.Next(), true)
|
|
|
|
err = scanRows(ctx, dialect, rows, test.scanTarget)
|
|
tt.AssertErrContains(t, err, test.expectErrToContain...)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("should report error for closed rows", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
|
tt.AssertNoErr(t, err)
|
|
|
|
var u user
|
|
err = rows.Close()
|
|
tt.AssertNoErr(t, err)
|
|
err = scanRows(ctx, dialect, rows, &u)
|
|
tt.AssertNotEqual(t, err, nil)
|
|
})
|
|
|
|
t.Run("should report if record is not a pointer", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
|
tt.AssertNoErr(t, err)
|
|
defer rows.Close()
|
|
|
|
var u user
|
|
err = scanRows(ctx, dialect, rows, u)
|
|
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
|
|
})
|
|
|
|
t.Run("should report if record is not a pointer to struct", func(t *testing.T) {
|
|
err := createTables(dialect, connStr)
|
|
if err != nil {
|
|
t.Fatal("could not create test table!, reason:", err.Error())
|
|
}
|
|
|
|
db, closer := newDBAdapter(t)
|
|
defer closer.Close()
|
|
|
|
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
|
tt.AssertNoErr(t, err)
|
|
defer rows.Close()
|
|
|
|
var u map[string]interface{}
|
|
err = scanRows(ctx, dialect, rows, &u)
|
|
tt.AssertErrContains(t, err, "KSQL", "expected", "pointer to struct", "map[string]interface")
|
|
})
|
|
})
|
|
}
|
|
|
|
func createTables(dialect sqldialect.Provider, connStr string) error {
|
|
driver := dialect.DriverName()
|
|
|
|
if connStr == "" {
|
|
return fmt.Errorf("unsupported dialect: '%s'", driver)
|
|
}
|
|
|
|
db, err := sql.Open(driver, connStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer db.Close()
|
|
|
|
db.Exec(`DROP TABLE users`)
|
|
|
|
switch driver {
|
|
case "sqlite3":
|
|
_, err = db.Exec(`CREATE TABLE users (
|
|
id INTEGER PRIMARY KEY,
|
|
age INTEGER,
|
|
name TEXT,
|
|
address BLOB,
|
|
created_at DATETIME,
|
|
updated_at DATETIME,
|
|
nullable_field TEXT DEFAULT "not_null"
|
|
)`)
|
|
case "postgres":
|
|
_, err = db.Exec(`CREATE TABLE users (
|
|
id serial PRIMARY KEY,
|
|
age INT,
|
|
name VARCHAR(50),
|
|
address jsonb,
|
|
created_at TIMESTAMP,
|
|
updated_at TIMESTAMP,
|
|
nullable_field VARCHAR(50) DEFAULT 'not_null'
|
|
)`)
|
|
case "mysql":
|
|
_, err = db.Exec(`CREATE TABLE users (
|
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
age INT,
|
|
name VARCHAR(50),
|
|
address JSON,
|
|
created_at DATETIME,
|
|
updated_at DATETIME,
|
|
nullable_field VARCHAR(50) DEFAULT "not_null"
|
|
)`)
|
|
case "sqlserver":
|
|
_, err = db.Exec(`CREATE TABLE users (
|
|
id INT IDENTITY(1,1) PRIMARY KEY,
|
|
age INT,
|
|
name VARCHAR(50),
|
|
address NVARCHAR(4000),
|
|
created_at DATETIME,
|
|
updated_at DATETIME,
|
|
nullable_field VARCHAR(50) DEFAULT 'not_null'
|
|
)`)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create new users table: %s", err.Error())
|
|
}
|
|
|
|
db.Exec(`DROP TABLE posts`)
|
|
|
|
switch driver {
|
|
case "sqlite3":
|
|
_, err = db.Exec(`CREATE TABLE posts (
|
|
id INTEGER PRIMARY KEY,
|
|
user_id INTEGER,
|
|
title TEXT
|
|
)`)
|
|
case "postgres":
|
|
_, err = db.Exec(`CREATE TABLE posts (
|
|
id serial PRIMARY KEY,
|
|
user_id INT,
|
|
title VARCHAR(50)
|
|
)`)
|
|
case "mysql":
|
|
_, err = db.Exec(`CREATE TABLE posts (
|
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
user_id INT,
|
|
title VARCHAR(50)
|
|
)`)
|
|
case "sqlserver":
|
|
_, err = db.Exec(`CREATE TABLE posts (
|
|
id INT IDENTITY(1,1) PRIMARY KEY,
|
|
user_id INT,
|
|
title VARCHAR(50)
|
|
)`)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create new posts table: %s", err.Error())
|
|
}
|
|
|
|
db.Exec(`DROP TABLE user_permissions`)
|
|
|
|
switch driver {
|
|
case "sqlite3":
|
|
_, err = db.Exec(`CREATE TABLE user_permissions (
|
|
id INTEGER PRIMARY KEY,
|
|
user_id INTEGER,
|
|
perm_id INTEGER,
|
|
type TEXT,
|
|
UNIQUE (user_id, perm_id)
|
|
)`)
|
|
case "postgres":
|
|
_, err = db.Exec(`CREATE TABLE user_permissions (
|
|
id serial PRIMARY KEY,
|
|
user_id INT,
|
|
perm_id INT,
|
|
type VARCHAR(50),
|
|
UNIQUE (user_id, perm_id)
|
|
)`)
|
|
case "mysql":
|
|
_, err = db.Exec(`CREATE TABLE user_permissions (
|
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
user_id INT,
|
|
perm_id INT,
|
|
type VARCHAR(50),
|
|
UNIQUE KEY (user_id, perm_id)
|
|
)`)
|
|
case "sqlserver":
|
|
_, err = db.Exec(`CREATE TABLE user_permissions (
|
|
id INT IDENTITY(1,1) PRIMARY KEY,
|
|
user_id INT,
|
|
perm_id INT,
|
|
type VARCHAR(50),
|
|
CONSTRAINT unique_1 UNIQUE (user_id, perm_id)
|
|
)`)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create new user_permissions table: %s", err.Error())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func newTestDB(db DBAdapter, dialect sqldialect.Provider) DB {
|
|
return DB{
|
|
dialect: dialect,
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// shiftErrSlice pops the first element of the slice pointed to by errs and
// returns it, leaving *errs with the remaining elements.
//
// It panics (index out of range) if the slice is empty.
func shiftErrSlice(errs *[]error) error {
	pending := *errs
	head := pending[0]
	*errs = pending[1:]
	return head
}
|
|
|
|
func getUserByID(db DBAdapter, dialect sqldialect.Provider, result *user, id uint) error {
|
|
rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() == false {
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
return sql.ErrNoRows
|
|
}
|
|
|
|
modifier, _ := modifiers.LoadGlobalModifier("json")
|
|
|
|
value := modifiers.AttrScanWrapper{
|
|
Ctx: context.TODO(),
|
|
AttrPtr: &result.Address,
|
|
ScanFn: modifier.Scan,
|
|
OpInfo: ksqlmodifiers.OpInfo{
|
|
DriverName: dialect.DriverName(),
|
|
// We will not differentiate between Query, QueryOne and QueryChunks
|
|
// if we did this could lead users to make very strange modifiers
|
|
Method: "Query",
|
|
},
|
|
}
|
|
|
|
err = rows.Scan(&result.ID, &result.Name, &result.Age, &value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getUserByName(db DBAdapter, dialect sqldialect.Provider, result *user, name string) error {
|
|
rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() == false {
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
return sql.ErrNoRows
|
|
}
|
|
|
|
var rawAddr []byte
|
|
err = rows.Scan(&result.ID, &result.Name, &result.Age, &rawAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if rawAddr == nil {
|
|
return nil
|
|
}
|
|
|
|
return json.Unmarshal(rawAddr, &result.Address)
|
|
}
|
|
|
|
func createUserPermission(db DBAdapter, dialect sqldialect.Provider, userPerm userPermission) error {
|
|
_, err := db.ExecContext(context.TODO(),
|
|
`INSERT INTO user_permissions (user_id, perm_id, type)
|
|
VALUES (`+dialect.Placeholder(0)+`, `+dialect.Placeholder(1)+`, `+dialect.Placeholder(2)+`)`,
|
|
userPerm.UserID, userPerm.PermID, userPerm.Type,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
func getUserPermissionBySecondaryKeys(db DBAdapter, dialect sqldialect.Provider, userID int, permID int) (userPermission, error) {
|
|
rows, err := db.QueryContext(context.TODO(),
|
|
`SELECT id, user_id, perm_id, type FROM user_permissions WHERE user_id=`+dialect.Placeholder(0)+` AND perm_id=`+dialect.Placeholder(1),
|
|
userID, permID,
|
|
)
|
|
if err != nil {
|
|
return userPermission{}, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() == false {
|
|
if rows.Err() != nil {
|
|
return userPermission{}, rows.Err()
|
|
}
|
|
return userPermission{}, sql.ErrNoRows
|
|
}
|
|
|
|
var result userPermission
|
|
err = rows.Scan(&result.ID, &result.UserID, &result.PermID, &result.Type)
|
|
if err != nil {
|
|
return userPermission{}, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func getUserPermissionsByUser(db DBAdapter, dialect sqldialect.Provider, userID int) (results []userPermission, _ error) {
|
|
rows, err := db.QueryContext(context.TODO(),
|
|
`SELECT id, user_id, perm_id FROM user_permissions WHERE user_id=`+dialect.Placeholder(0),
|
|
userID,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var userPerm userPermission
|
|
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PermID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
results = append(results, userPerm)
|
|
}
|
|
if rows.Err() != nil {
|
|
return nil, rows.Err()
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func mustBuildSelectQuery(t *testing.T,
|
|
dialect sqldialect.Provider,
|
|
record interface{},
|
|
query string,
|
|
) string {
|
|
if strings.HasPrefix(query, "SELECT") {
|
|
return query
|
|
}
|
|
|
|
structType := reflect.TypeOf(record).Elem()
|
|
structInfo, err := structs.GetTagInfo(structType)
|
|
tt.AssertNoErr(t, err)
|
|
|
|
selectPrefix, err := buildSelectQuery(dialect, structType, structInfo, selectQueryCache[dialect.DriverName()])
|
|
tt.AssertNoErr(t, err)
|
|
|
|
return selectPrefix + query
|
|
}
|