Add feature of nesting structs so we can reuse existing structs

pull/2/head
Vinícius Garcia 2021-05-23 11:28:16 -03:00
parent d8ca3cab8d
commit 0d3a75fe42
3 changed files with 261 additions and 60 deletions

100
ksql.go
View File

@ -813,11 +813,6 @@ func (nopScanner) Scan(value interface{}) error {
}
func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
names, err := rows.Columns()
if err != nil {
return err
}
v := reflect.ValueOf(record)
t := v.Type()
if t.Kind() != reflect.Ptr {
@ -833,6 +828,53 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
info := structs.GetTagInfo(t)
var scanArgs []interface{}
if info.IsNestedStruct {
// This version is positional meaning that it expect the arguments
// to follow an specific order. It's ok because we don't allow the
// user to type the "SELECT" part of the query for nested structs.
scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info)
} else {
names, err := rows.Columns()
if err != nil {
return err
}
// Since this version uses the names of the columns it works
// with any order of attributes/columns.
scanArgs = getScanArgsFromNames(dialect, names, v, info)
}
return rows.Scan(scanArgs...)
}
func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) []interface{} {
scanArgs := []interface{}{}
for i := 0; i < v.NumField(); i++ {
// TODO(vingarcia00): Handle case where type is pointer
nestedStructInfo := structs.GetTagInfo(t.Field(i).Type)
nestedStructValue := v.Field(i)
for j := 0; j < nestedStructValue.NumField(); j++ {
fieldInfo := nestedStructInfo.ByIndex(j)
valueScanner := nopScannerValue
if fieldInfo.Valid {
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.SerializeAsJSON {
valueScanner = &jsonSerializable{
DriverName: dialect.DriverName(),
Attr: valueScanner,
}
}
}
scanArgs = append(scanArgs, valueScanner)
}
}
return scanArgs
}
func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
scanArgs := []interface{}{}
for _, name := range names {
fieldInfo := info.ByName(name)
@ -851,7 +893,7 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
scanArgs = append(scanArgs, valueScanner)
}
return rows.Scan(scanArgs...)
return scanArgs
}
func buildSingleKeyDeleteQuery(
@ -923,18 +965,54 @@ func buildSelectQuery(
dialect dialect,
structType reflect.Type,
selectQueryCache map[reflect.Type]string,
) (string, error) {
) (query string, err error) {
if selectQuery, found := selectQueryCache[structType]; found {
return selectQuery, nil
}
info := structs.GetTagInfo(structType)
var fields []string
for _, field := range info.Fields() {
fields = append(fields, dialect.Escape(field.Name))
if info.IsNestedStruct {
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
if err != nil {
return "", err
}
} else {
query = buildSelectQueryForPlainStructs(dialect, structType, info)
}
query := "SELECT " + strings.Join(fields, ", ") + " "
selectQueryCache[structType] = query
return query, nil
}
func buildSelectQueryForPlainStructs(
dialect dialect,
structType reflect.Type,
info structs.StructInfo,
) string {
var fields []string
for i := 0; i < structType.NumField(); i++ {
fields = append(fields, dialect.Escape(info.ByIndex(i).Name))
}
return "SELECT " + strings.Join(fields, ", ") + " "
}
func buildSelectQueryForNestedStructs(
dialect dialect,
structType reflect.Type,
info structs.StructInfo,
) (string, error) {
var fields []string
for i := 0; i < structType.NumField(); i++ {
nestedStructName := info.ByIndex(i).Name
nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type)
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
fields = append(
fields,
dialect.Escape(nestedStructName)+"."+dialect.Escape(nestedStructInfo.ByIndex(j).Name),
)
}
}
return "SELECT " + strings.Join(fields, ", ") + " ", nil
}

View File

@ -34,6 +34,12 @@ type Address struct {
Country string `json:"country"`
}
type Post struct {
ID int `ksql:"id"`
UserID uint `ksql:"user_id"`
Title string `ksql:"title"`
}
func TestQuery(t *testing.T) {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
@ -53,7 +59,7 @@ func TestQuery(t *testing.T) {
for _, variation := range variations {
t.Run(variation.desc, func(t *testing.T) {
t.Run("using slice of structs", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -123,7 +129,7 @@ func TestQuery(t *testing.T) {
})
t.Run("using slice of pointers to structs", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -190,12 +196,66 @@ func TestQuery(t *testing.T) {
assert.Equal(t, "Bia Garcia", users[1].Name)
assert.Equal(t, "BR", users[1].Address.Country)
})
t.Run("should query joined tables correctly", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
// This test only makes sense with no query prefix
if variation.queryPrefix != "" {
return
}
_, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
assert.Equal(t, nil, err)
var joaoID uint
db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID)
_, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`)
assert.Equal(t, nil, err)
var biaID uint
db.QueryRow(`SELECT id FROM users WHERE name = 'Bia Ribeiro'`).Scan(&biaID)
_, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post1')`))
assert.Equal(t, nil, err)
_, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post2')`))
assert.Equal(t, nil, err)
_, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`))
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
var rows []*struct {
User User `tablename:"u"`
Post Post `tablename:"p"`
}
err = c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
assert.Equal(t, nil, err)
assert.Equal(t, 3, len(rows))
assert.Equal(t, joaoID, rows[0].User.ID)
assert.Equal(t, "João Ribeiro", rows[0].User.Name)
assert.Equal(t, "João Post1", rows[0].Post.Title)
assert.Equal(t, biaID, rows[1].User.ID)
assert.Equal(t, "Bia Ribeiro", rows[1].User.Name)
assert.Equal(t, "Bia Post1", rows[1].Post.Title)
assert.Equal(t, biaID, rows[2].User.ID)
assert.Equal(t, "Bia Ribeiro", rows[2].User.Name)
assert.Equal(t, "Bia Post2", rows[2].Post.Title)
})
})
})
}
t.Run("testing error cases", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -258,7 +318,7 @@ func TestQueryOne(t *testing.T) {
},
}
for _, variation := range variations {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -358,7 +418,7 @@ func TestInsert(t *testing.T) {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
t.Run("using slice of structs", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -428,7 +488,7 @@ func TestInsert(t *testing.T) {
})
t.Run("testing error cases", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -485,7 +545,7 @@ func TestInsert(t *testing.T) {
func TestDelete(t *testing.T) {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -628,7 +688,7 @@ func TestDelete(t *testing.T) {
func TestUpdate(t *testing.T) {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -813,7 +873,7 @@ func TestQueryChunks(t *testing.T) {
for _, variation := range variations {
t.Run(variation.desc, func(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -853,7 +913,7 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should query one chunk correctly", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -895,7 +955,7 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should query chunks of 1 correctly", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -937,7 +997,7 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should load partially filled chunks correctly", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -978,7 +1038,7 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1017,7 +1077,7 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1060,7 +1120,7 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1099,7 +1159,7 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1213,7 +1273,7 @@ func TestTransaction(t *testing.T) {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1240,7 +1300,7 @@ func TestTransaction(t *testing.T) {
})
t.Run("should rollback when there are errors", func(t *testing.T) {
err := createTable(driver)
err := createTables(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1281,7 +1341,7 @@ func TestTransaction(t *testing.T) {
func TestScanRows(t *testing.T) {
t.Run("should scan users correctly", func(t *testing.T) {
err := createTable("sqlite3")
err := createTables("sqlite3")
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1310,7 +1370,7 @@ func TestScanRows(t *testing.T) {
})
t.Run("should ignore extra columns from query", func(t *testing.T) {
err := createTable("sqlite3")
err := createTables("sqlite3")
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1342,7 +1402,7 @@ func TestScanRows(t *testing.T) {
})
t.Run("should report error for closed rows", func(t *testing.T) {
err := createTable("sqlite3")
err := createTables("sqlite3")
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1363,7 +1423,7 @@ func TestScanRows(t *testing.T) {
})
t.Run("should report if record is not a pointer", func(t *testing.T) {
err := createTable("sqlite3")
err := createTables("sqlite3")
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1382,7 +1442,7 @@ func TestScanRows(t *testing.T) {
})
t.Run("should report if record is not a pointer to struct", func(t *testing.T) {
err := createTable("sqlite3")
err := createTables("sqlite3")
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
@ -1408,7 +1468,7 @@ var connectionString = map[string]string{
"sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql",
}
func createTable(driver string) error {
func createTables(driver string) error {
connStr := connectionString[driver]
if connStr == "" {
return fmt.Errorf("unsupported driver: '%s'", driver)
@ -1456,6 +1516,38 @@ func createTable(driver string) error {
return fmt.Errorf("failed to create new users table: %s", err.Error())
}
db.Exec(`DROP TABLE posts`)
switch driver {
case "sqlite3":
_, err = db.Exec(`CREATE TABLE posts (
id INTEGER PRIMARY KEY,
user_id INTEGER,
title TEXT
)`)
case "postgres":
_, err = db.Exec(`CREATE TABLE posts (
id serial PRIMARY KEY,
user_id INT,
title VARCHAR(50)
)`)
case "mysql":
_, err = db.Exec(`CREATE TABLE posts (
id INT AUTO_INCREMENT PRIMARY KEY,
user_id INT,
title VARCHAR(50)
)`)
case "sqlserver":
_, err = db.Exec(`CREATE TABLE posts (
id INT IDENTITY(1,1) PRIMARY KEY,
user_id INT,
title VARCHAR(50)
)`)
}
if err != nil {
return fmt.Errorf("failed to create new users table: %s", err.Error())
}
return nil
}

View File

@ -8,49 +8,56 @@ import (
"github.com/pkg/errors"
)
type structInfo struct {
byIndex map[int]*fieldInfo
byName map[string]*fieldInfo
// StructInfo stores metainformation of the struct
// parser in order to help the ksql library to work
// efectively and efficiently with reflection.
type StructInfo struct {
IsNestedStruct bool
byIndex map[int]*FieldInfo
byName map[string]*FieldInfo
}
type fieldInfo struct {
// FieldInfo contains reflection and tags
// information regarding a specific field
// of a struct.
type FieldInfo struct {
Name string
Index int
Valid bool
SerializeAsJSON bool
}
func (s structInfo) ByIndex(idx int) *fieldInfo {
// ByIndex returns either the *FieldInfo of a valid
// empty struct with Valid set to false
func (s StructInfo) ByIndex(idx int) *FieldInfo {
field, found := s.byIndex[idx]
if !found {
return &fieldInfo{}
return &FieldInfo{}
}
return field
}
func (s structInfo) ByName(name string) *fieldInfo {
// ByName returns either the *FieldInfo of a valid
// empty struct with Valid set to false
func (s StructInfo) ByName(name string) *FieldInfo {
field, found := s.byName[name]
if !found {
return &fieldInfo{}
return &FieldInfo{}
}
return field
}
func (s structInfo) Add(field fieldInfo) {
func (s StructInfo) add(field FieldInfo) {
field.Valid = true
s.byIndex[field.Index] = &field
s.byName[field.Name] = &field
}
func (s structInfo) Fields() map[int]*fieldInfo {
return s.byIndex
}
// This cache is kept as a pkg variable
// because the total number of types on a program
// should be finite. So keeping a single cache here
// works fine.
var tagInfoCache = map[reflect.Type]structInfo{}
var tagInfoCache = map[reflect.Type]StructInfo{}
// GetTagInfo efficiently returns the type information
// using a global private cache
@ -58,16 +65,17 @@ var tagInfoCache = map[reflect.Type]structInfo{}
// In the future we might move this cache inside
// a struct, but for now this accessor is the one
// we are using
func GetTagInfo(key reflect.Type) structInfo {
func GetTagInfo(key reflect.Type) StructInfo {
return getCachedTagInfo(tagInfoCache, key)
}
func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo {
info, found := tagInfoCache[key]
if !found {
info = getTagNames(key)
tagInfoCache[key] = info
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo {
if info, found := tagInfoCache[key]; found {
return info
}
info := getTagNames(key)
tagInfoCache[key] = info
return info
}
@ -291,10 +299,10 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
//
// This should save several calls to `Field(i).Tag.Get("foo")`
// which improves performance by a lot.
func getTagNames(t reflect.Type) structInfo {
info := structInfo{
byIndex: map[int]*fieldInfo{},
byName: map[string]*fieldInfo{},
func getTagNames(t reflect.Type) StructInfo {
info := StructInfo{
byIndex: map[int]*FieldInfo{},
byName: map[string]*FieldInfo{},
}
for i := 0; i < t.NumField(); i++ {
name := t.Field(i).Tag.Get("ksql")
@ -309,13 +317,36 @@ func getTagNames(t reflect.Type) structInfo {
serializeAsJSON = tags[1] == "json"
}
info.Add(fieldInfo{
info.add(FieldInfo{
Name: name,
Index: i,
SerializeAsJSON: serializeAsJSON,
})
}
// If there were `ksql` tags present, then we are finished:
if len(info.byIndex) > 0 {
return info
}
// If there are no `ksql` tags in the struct, lets assume
// it is a struct tagged with `tablename` for allowing JOINs
for i := 0; i < t.NumField(); i++ {
name := t.Field(i).Tag.Get("tablename")
if name == "" {
continue
}
info.add(FieldInfo{
Name: name,
Index: i,
})
}
if len(info.byIndex) > 0 {
info.IsNestedStruct = true
}
return info
}