mirror of https://github.com/VinGarcia/ksql.git
Add feature of nesting structs so we can reuse existing structs
parent
d8ca3cab8d
commit
0d3a75fe42
100
ksql.go
100
ksql.go
|
@ -813,11 +813,6 @@ func (nopScanner) Scan(value interface{}) error {
|
|||
}
|
||||
|
||||
func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
|
||||
names, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
if t.Kind() != reflect.Ptr {
|
||||
|
@ -833,6 +828,53 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
|
|||
|
||||
info := structs.GetTagInfo(t)
|
||||
|
||||
var scanArgs []interface{}
|
||||
if info.IsNestedStruct {
|
||||
// This version is positional meaning that it expect the arguments
|
||||
// to follow an specific order. It's ok because we don't allow the
|
||||
// user to type the "SELECT" part of the query for nested structs.
|
||||
scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
} else {
|
||||
names, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Since this version uses the names of the columns it works
|
||||
// with any order of attributes/columns.
|
||||
scanArgs = getScanArgsFromNames(dialect, names, v, info)
|
||||
}
|
||||
|
||||
return rows.Scan(scanArgs...)
|
||||
}
|
||||
|
||||
func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) []interface{} {
|
||||
scanArgs := []interface{}{}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
// TODO(vingarcia00): Handle case where type is pointer
|
||||
nestedStructInfo := structs.GetTagInfo(t.Field(i).Type)
|
||||
nestedStructValue := v.Field(i)
|
||||
for j := 0; j < nestedStructValue.NumField(); j++ {
|
||||
fieldInfo := nestedStructInfo.ByIndex(j)
|
||||
|
||||
valueScanner := nopScannerValue
|
||||
if fieldInfo.Valid {
|
||||
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
|
||||
if fieldInfo.SerializeAsJSON {
|
||||
valueScanner = &jsonSerializable{
|
||||
DriverName: dialect.DriverName(),
|
||||
Attr: valueScanner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scanArgs = append(scanArgs, valueScanner)
|
||||
}
|
||||
}
|
||||
|
||||
return scanArgs
|
||||
}
|
||||
|
||||
func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
|
||||
scanArgs := []interface{}{}
|
||||
for _, name := range names {
|
||||
fieldInfo := info.ByName(name)
|
||||
|
@ -851,7 +893,7 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
|
|||
scanArgs = append(scanArgs, valueScanner)
|
||||
}
|
||||
|
||||
return rows.Scan(scanArgs...)
|
||||
return scanArgs
|
||||
}
|
||||
|
||||
func buildSingleKeyDeleteQuery(
|
||||
|
@ -923,18 +965,54 @@ func buildSelectQuery(
|
|||
dialect dialect,
|
||||
structType reflect.Type,
|
||||
selectQueryCache map[reflect.Type]string,
|
||||
) (string, error) {
|
||||
) (query string, err error) {
|
||||
if selectQuery, found := selectQueryCache[structType]; found {
|
||||
return selectQuery, nil
|
||||
}
|
||||
|
||||
info := structs.GetTagInfo(structType)
|
||||
var fields []string
|
||||
for _, field := range info.Fields() {
|
||||
fields = append(fields, dialect.Escape(field.Name))
|
||||
if info.IsNestedStruct {
|
||||
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
query = buildSelectQueryForPlainStructs(dialect, structType, info)
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " "
|
||||
selectQueryCache[structType] = query
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func buildSelectQueryForPlainStructs(
|
||||
dialect dialect,
|
||||
structType reflect.Type,
|
||||
info structs.StructInfo,
|
||||
) string {
|
||||
var fields []string
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
fields = append(fields, dialect.Escape(info.ByIndex(i).Name))
|
||||
}
|
||||
|
||||
return "SELECT " + strings.Join(fields, ", ") + " "
|
||||
}
|
||||
|
||||
func buildSelectQueryForNestedStructs(
|
||||
dialect dialect,
|
||||
structType reflect.Type,
|
||||
info structs.StructInfo,
|
||||
) (string, error) {
|
||||
var fields []string
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
nestedStructName := info.ByIndex(i).Name
|
||||
nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type)
|
||||
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
|
||||
fields = append(
|
||||
fields,
|
||||
dialect.Escape(nestedStructName)+"."+dialect.Escape(nestedStructInfo.ByIndex(j).Name),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return "SELECT " + strings.Join(fields, ", ") + " ", nil
|
||||
}
|
||||
|
|
140
ksql_test.go
140
ksql_test.go
|
@ -34,6 +34,12 @@ type Address struct {
|
|||
Country string `json:"country"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID int `ksql:"id"`
|
||||
UserID uint `ksql:"user_id"`
|
||||
Title string `ksql:"title"`
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
for driver := range supportedDialects {
|
||||
t.Run(driver, func(t *testing.T) {
|
||||
|
@ -53,7 +59,7 @@ func TestQuery(t *testing.T) {
|
|||
for _, variation := range variations {
|
||||
t.Run(variation.desc, func(t *testing.T) {
|
||||
t.Run("using slice of structs", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -123,7 +129,7 @@ func TestQuery(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("using slice of pointers to structs", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -190,12 +196,66 @@ func TestQuery(t *testing.T) {
|
|||
assert.Equal(t, "Bia Garcia", users[1].Name)
|
||||
assert.Equal(t, "BR", users[1].Address.Country)
|
||||
})
|
||||
|
||||
t.Run("should query joined tables correctly", func(t *testing.T) {
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
// This test only makes sense with no query prefix
|
||||
if variation.queryPrefix != "" {
|
||||
return
|
||||
}
|
||||
|
||||
_, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`)
|
||||
assert.Equal(t, nil, err)
|
||||
var joaoID uint
|
||||
db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID)
|
||||
|
||||
_, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`)
|
||||
assert.Equal(t, nil, err)
|
||||
var biaID uint
|
||||
db.QueryRow(`SELECT id FROM users WHERE name = 'Bia Ribeiro'`).Scan(&biaID)
|
||||
|
||||
_, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post1')`))
|
||||
assert.Equal(t, nil, err)
|
||||
_, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post2')`))
|
||||
assert.Equal(t, nil, err)
|
||||
_, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, driver, "users")
|
||||
var rows []*struct {
|
||||
User User `tablename:"u"`
|
||||
Post Post `tablename:"p"`
|
||||
}
|
||||
err = c.Query(ctx, &rows, fmt.Sprint(
|
||||
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
||||
` WHERE u.name like `, c.dialect.Placeholder(0),
|
||||
` ORDER BY u.id, p.id`,
|
||||
), "% Ribeiro")
|
||||
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 3, len(rows))
|
||||
|
||||
assert.Equal(t, joaoID, rows[0].User.ID)
|
||||
assert.Equal(t, "João Ribeiro", rows[0].User.Name)
|
||||
assert.Equal(t, "João Post1", rows[0].Post.Title)
|
||||
|
||||
assert.Equal(t, biaID, rows[1].User.ID)
|
||||
assert.Equal(t, "Bia Ribeiro", rows[1].User.Name)
|
||||
assert.Equal(t, "Bia Post1", rows[1].Post.Title)
|
||||
|
||||
assert.Equal(t, biaID, rows[2].User.ID)
|
||||
assert.Equal(t, "Bia Ribeiro", rows[2].User.Name)
|
||||
assert.Equal(t, "Bia Post2", rows[2].Post.Title)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("testing error cases", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -258,7 +318,7 @@ func TestQueryOne(t *testing.T) {
|
|||
},
|
||||
}
|
||||
for _, variation := range variations {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -358,7 +418,7 @@ func TestInsert(t *testing.T) {
|
|||
for driver := range supportedDialects {
|
||||
t.Run(driver, func(t *testing.T) {
|
||||
t.Run("using slice of structs", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -428,7 +488,7 @@ func TestInsert(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("testing error cases", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -485,7 +545,7 @@ func TestInsert(t *testing.T) {
|
|||
func TestDelete(t *testing.T) {
|
||||
for driver := range supportedDialects {
|
||||
t.Run(driver, func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -628,7 +688,7 @@ func TestDelete(t *testing.T) {
|
|||
func TestUpdate(t *testing.T) {
|
||||
for driver := range supportedDialects {
|
||||
t.Run(driver, func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -813,7 +873,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
for _, variation := range variations {
|
||||
t.Run(variation.desc, func(t *testing.T) {
|
||||
t.Run("should query a single row correctly", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -853,7 +913,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should query one chunk correctly", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -895,7 +955,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should query chunks of 1 correctly", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -937,7 +997,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should load partially filled chunks correctly", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -978,7 +1038,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1017,7 +1077,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1060,7 +1120,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1099,7 +1159,7 @@ func TestQueryChunks(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1213,7 +1273,7 @@ func TestTransaction(t *testing.T) {
|
|||
for driver := range supportedDialects {
|
||||
t.Run(driver, func(t *testing.T) {
|
||||
t.Run("should query a single row correctly", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1240,7 +1300,7 @@ func TestTransaction(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should rollback when there are errors", func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
err := createTables(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1281,7 +1341,7 @@ func TestTransaction(t *testing.T) {
|
|||
|
||||
func TestScanRows(t *testing.T) {
|
||||
t.Run("should scan users correctly", func(t *testing.T) {
|
||||
err := createTable("sqlite3")
|
||||
err := createTables("sqlite3")
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1310,7 +1370,7 @@ func TestScanRows(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should ignore extra columns from query", func(t *testing.T) {
|
||||
err := createTable("sqlite3")
|
||||
err := createTables("sqlite3")
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1342,7 +1402,7 @@ func TestScanRows(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should report error for closed rows", func(t *testing.T) {
|
||||
err := createTable("sqlite3")
|
||||
err := createTables("sqlite3")
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1363,7 +1423,7 @@ func TestScanRows(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should report if record is not a pointer", func(t *testing.T) {
|
||||
err := createTable("sqlite3")
|
||||
err := createTables("sqlite3")
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1382,7 +1442,7 @@ func TestScanRows(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should report if record is not a pointer to struct", func(t *testing.T) {
|
||||
err := createTable("sqlite3")
|
||||
err := createTables("sqlite3")
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
@ -1408,7 +1468,7 @@ var connectionString = map[string]string{
|
|||
"sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql",
|
||||
}
|
||||
|
||||
func createTable(driver string) error {
|
||||
func createTables(driver string) error {
|
||||
connStr := connectionString[driver]
|
||||
if connStr == "" {
|
||||
return fmt.Errorf("unsupported driver: '%s'", driver)
|
||||
|
@ -1456,6 +1516,38 @@ func createTable(driver string) error {
|
|||
return fmt.Errorf("failed to create new users table: %s", err.Error())
|
||||
}
|
||||
|
||||
db.Exec(`DROP TABLE posts`)
|
||||
|
||||
switch driver {
|
||||
case "sqlite3":
|
||||
_, err = db.Exec(`CREATE TABLE posts (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER,
|
||||
title TEXT
|
||||
)`)
|
||||
case "postgres":
|
||||
_, err = db.Exec(`CREATE TABLE posts (
|
||||
id serial PRIMARY KEY,
|
||||
user_id INT,
|
||||
title VARCHAR(50)
|
||||
)`)
|
||||
case "mysql":
|
||||
_, err = db.Exec(`CREATE TABLE posts (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
user_id INT,
|
||||
title VARCHAR(50)
|
||||
)`)
|
||||
case "sqlserver":
|
||||
_, err = db.Exec(`CREATE TABLE posts (
|
||||
id INT IDENTITY(1,1) PRIMARY KEY,
|
||||
user_id INT,
|
||||
title VARCHAR(50)
|
||||
)`)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new users table: %s", err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -8,49 +8,56 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type structInfo struct {
|
||||
byIndex map[int]*fieldInfo
|
||||
byName map[string]*fieldInfo
|
||||
// StructInfo stores metainformation of the struct
|
||||
// parser in order to help the ksql library to work
|
||||
// efectively and efficiently with reflection.
|
||||
type StructInfo struct {
|
||||
IsNestedStruct bool
|
||||
byIndex map[int]*FieldInfo
|
||||
byName map[string]*FieldInfo
|
||||
}
|
||||
|
||||
type fieldInfo struct {
|
||||
// FieldInfo contains reflection and tags
|
||||
// information regarding a specific field
|
||||
// of a struct.
|
||||
type FieldInfo struct {
|
||||
Name string
|
||||
Index int
|
||||
Valid bool
|
||||
SerializeAsJSON bool
|
||||
}
|
||||
|
||||
func (s structInfo) ByIndex(idx int) *fieldInfo {
|
||||
// ByIndex returns either the *FieldInfo of a valid
|
||||
// empty struct with Valid set to false
|
||||
func (s StructInfo) ByIndex(idx int) *FieldInfo {
|
||||
field, found := s.byIndex[idx]
|
||||
if !found {
|
||||
return &fieldInfo{}
|
||||
return &FieldInfo{}
|
||||
}
|
||||
return field
|
||||
}
|
||||
|
||||
func (s structInfo) ByName(name string) *fieldInfo {
|
||||
// ByName returns either the *FieldInfo of a valid
|
||||
// empty struct with Valid set to false
|
||||
func (s StructInfo) ByName(name string) *FieldInfo {
|
||||
field, found := s.byName[name]
|
||||
if !found {
|
||||
return &fieldInfo{}
|
||||
return &FieldInfo{}
|
||||
}
|
||||
return field
|
||||
}
|
||||
|
||||
func (s structInfo) Add(field fieldInfo) {
|
||||
func (s StructInfo) add(field FieldInfo) {
|
||||
field.Valid = true
|
||||
s.byIndex[field.Index] = &field
|
||||
s.byName[field.Name] = &field
|
||||
}
|
||||
|
||||
func (s structInfo) Fields() map[int]*fieldInfo {
|
||||
return s.byIndex
|
||||
}
|
||||
|
||||
// This cache is kept as a pkg variable
|
||||
// because the total number of types on a program
|
||||
// should be finite. So keeping a single cache here
|
||||
// works fine.
|
||||
var tagInfoCache = map[reflect.Type]structInfo{}
|
||||
var tagInfoCache = map[reflect.Type]StructInfo{}
|
||||
|
||||
// GetTagInfo efficiently returns the type information
|
||||
// using a global private cache
|
||||
|
@ -58,16 +65,17 @@ var tagInfoCache = map[reflect.Type]structInfo{}
|
|||
// In the future we might move this cache inside
|
||||
// a struct, but for now this accessor is the one
|
||||
// we are using
|
||||
func GetTagInfo(key reflect.Type) structInfo {
|
||||
func GetTagInfo(key reflect.Type) StructInfo {
|
||||
return getCachedTagInfo(tagInfoCache, key)
|
||||
}
|
||||
|
||||
func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo {
|
||||
info, found := tagInfoCache[key]
|
||||
if !found {
|
||||
info = getTagNames(key)
|
||||
tagInfoCache[key] = info
|
||||
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo {
|
||||
if info, found := tagInfoCache[key]; found {
|
||||
return info
|
||||
}
|
||||
|
||||
info := getTagNames(key)
|
||||
tagInfoCache[key] = info
|
||||
return info
|
||||
}
|
||||
|
||||
|
@ -291,10 +299,10 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
|
|||
//
|
||||
// This should save several calls to `Field(i).Tag.Get("foo")`
|
||||
// which improves performance by a lot.
|
||||
func getTagNames(t reflect.Type) structInfo {
|
||||
info := structInfo{
|
||||
byIndex: map[int]*fieldInfo{},
|
||||
byName: map[string]*fieldInfo{},
|
||||
func getTagNames(t reflect.Type) StructInfo {
|
||||
info := StructInfo{
|
||||
byIndex: map[int]*FieldInfo{},
|
||||
byName: map[string]*FieldInfo{},
|
||||
}
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
name := t.Field(i).Tag.Get("ksql")
|
||||
|
@ -309,13 +317,36 @@ func getTagNames(t reflect.Type) structInfo {
|
|||
serializeAsJSON = tags[1] == "json"
|
||||
}
|
||||
|
||||
info.Add(fieldInfo{
|
||||
info.add(FieldInfo{
|
||||
Name: name,
|
||||
Index: i,
|
||||
SerializeAsJSON: serializeAsJSON,
|
||||
})
|
||||
}
|
||||
|
||||
// If there were `ksql` tags present, then we are finished:
|
||||
if len(info.byIndex) > 0 {
|
||||
return info
|
||||
}
|
||||
|
||||
// If there are no `ksql` tags in the struct, lets assume
|
||||
// it is a struct tagged with `tablename` for allowing JOINs
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
name := t.Field(i).Tag.Get("tablename")
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
info.add(FieldInfo{
|
||||
Name: name,
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
if len(info.byIndex) > 0 {
|
||||
info.IsNestedStruct = true
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue