mirror of https://github.com/VinGarcia/ksql.git
parent
e970a3546a
commit
82a43fda87
|
@ -73,7 +73,10 @@ func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []inte
|
|||
return "", nil, fmt.Errorf("expected Data attr to be a struct or slice of structs but got: %v", t)
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(t)
|
||||
info, err := kstructs.GetTagInfo(t)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
b.WriteString(" (")
|
||||
var escapedNames []string
|
||||
|
|
|
@ -204,7 +204,10 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
|
|||
return query, nil
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(t)
|
||||
info, err := kstructs.GetTagInfo(t)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var escapedNames []string
|
||||
for i := 0; i < info.NumFields(); i++ {
|
||||
|
|
51
ksql.go
51
ksql.go
|
@ -185,7 +185,10 @@ func (c DB) Query(
|
|||
slice = slice.Slice(0, 0)
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(structType)
|
||||
info, err := kstructs.GetTagInfo(structType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
firstToken := strings.ToUpper(getFirstToken(query))
|
||||
if info.IsNestedStruct && firstToken == "SELECT" {
|
||||
|
@ -272,7 +275,10 @@ func (c DB) QueryOne(
|
|||
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(tStruct)
|
||||
info, err := kstructs.GetTagInfo(tStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
firstToken := strings.ToUpper(getFirstToken(query))
|
||||
if info.IsNestedStruct && firstToken == "SELECT" {
|
||||
|
@ -342,7 +348,10 @@ func (c DB) QueryChunks(
|
|||
return err
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(structType)
|
||||
info, err := kstructs.GetTagInfo(structType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
firstToken := strings.ToUpper(getFirstToken(parser.Query))
|
||||
if info.IsNestedStruct && firstToken == "SELECT" {
|
||||
|
@ -445,7 +454,10 @@ func (c DB) Insert(
|
|||
return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(t.Elem())
|
||||
info, err := kstructs.GetTagInfo(t.Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, t, v, info, record, table.idColumns...)
|
||||
if err != nil {
|
||||
|
@ -676,7 +688,10 @@ func (c DB) Update(
|
|||
}
|
||||
tStruct = t.Elem()
|
||||
}
|
||||
info := kstructs.GetTagInfo(tStruct)
|
||||
info, err := kstructs.GetTagInfo(tStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...)
|
||||
if err != nil {
|
||||
|
@ -940,14 +955,20 @@ func scanRowsFromType(
|
|||
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(t)
|
||||
info, err := kstructs.GetTagInfo(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var scanArgs []interface{}
|
||||
if info.IsNestedStruct {
|
||||
// This version is positional meaning that it expect the arguments
|
||||
// to follow an specific order. It's ok because we don't allow the
|
||||
// user to type the "SELECT" part of the query for nested kstructs.
|
||||
scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
names, err := rows.Columns()
|
||||
if err != nil {
|
||||
|
@ -961,11 +982,15 @@ func scanRowsFromType(
|
|||
return rows.Scan(scanArgs...)
|
||||
}
|
||||
|
||||
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} {
|
||||
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) ([]interface{}, error) {
|
||||
scanArgs := []interface{}{}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
// TODO(vingarcia00): Handle case where type is pointer
|
||||
nestedStructInfo := kstructs.GetTagInfo(t.Field(i).Type)
|
||||
nestedStructInfo, err := kstructs.GetTagInfo(t.Field(i).Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nestedStructValue := v.Field(i)
|
||||
for j := 0; j < nestedStructValue.NumField(); j++ {
|
||||
fieldInfo := nestedStructInfo.ByIndex(j)
|
||||
|
@ -985,7 +1010,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
|
|||
}
|
||||
}
|
||||
|
||||
return scanArgs
|
||||
return scanArgs, nil
|
||||
}
|
||||
|
||||
func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info kstructs.StructInfo) []interface{} {
|
||||
|
@ -1127,7 +1152,11 @@ func buildSelectQueryForNestedStructs(
|
|||
)
|
||||
}
|
||||
|
||||
nestedStructInfo := kstructs.GetTagInfo(nestedStructType)
|
||||
nestedStructInfo, err := kstructs.GetTagInfo(nestedStructType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
|
||||
fields = append(
|
||||
fields,
|
||||
|
|
|
@ -68,18 +68,22 @@ var tagInfoCache = map[reflect.Type]StructInfo{}
|
|||
// In the future we might move this cache inside
|
||||
// a struct, but for now this accessor is the one
|
||||
// we are using
|
||||
func GetTagInfo(key reflect.Type) StructInfo {
|
||||
func GetTagInfo(key reflect.Type) (StructInfo, error) {
|
||||
return getCachedTagInfo(tagInfoCache, key)
|
||||
}
|
||||
|
||||
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo {
|
||||
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) (StructInfo, error) {
|
||||
if info, found := tagInfoCache[key]; found {
|
||||
return info
|
||||
return info, nil
|
||||
}
|
||||
|
||||
info, err := getTagNames(key)
|
||||
if err != nil {
|
||||
return StructInfo{}, err
|
||||
}
|
||||
|
||||
info := getTagNames(key)
|
||||
tagInfoCache[key] = info
|
||||
return info
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// StructToMap converts any struct type to a map based on
|
||||
|
@ -103,11 +107,15 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
|
|||
return nil, fmt.Errorf("input must be a struct or struct pointer")
|
||||
}
|
||||
|
||||
info := getCachedTagInfo(tagInfoCache, t)
|
||||
info, err := getCachedTagInfo(tagInfoCache, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := map[string]interface{}{}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !info.ByIndex(i).Valid {
|
||||
fieldInfo := info.ByIndex(i)
|
||||
if !fieldInfo.Valid {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -121,7 +129,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
|
|||
field = field.Elem()
|
||||
}
|
||||
|
||||
m[info.ByIndex(i).Name] = field.Interface()
|
||||
m[fieldInfo.Name] = field.Interface()
|
||||
}
|
||||
|
||||
return m, nil
|
||||
|
@ -213,7 +221,7 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) {
|
|||
//
|
||||
// This should save several calls to `Field(i).Tag.Get("foo")`
|
||||
// which improves performance by a lot.
|
||||
func getTagNames(t reflect.Type) StructInfo {
|
||||
func getTagNames(t reflect.Type) (StructInfo, error) {
|
||||
info := StructInfo{
|
||||
byIndex: map[int]*FieldInfo{},
|
||||
byName: map[string]*FieldInfo{},
|
||||
|
@ -231,6 +239,13 @@ func getTagNames(t reflect.Type) StructInfo {
|
|||
serializeAsJSON = tags[1] == "json"
|
||||
}
|
||||
|
||||
if _, found := info.byName[name]; found {
|
||||
return StructInfo{}, fmt.Errorf(
|
||||
"struct contains multiple attributes with the same ksql tag name: '%s'",
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
info.add(FieldInfo{
|
||||
Name: name,
|
||||
Index: i,
|
||||
|
@ -240,7 +255,7 @@ func getTagNames(t reflect.Type) StructInfo {
|
|||
|
||||
// If there were `ksql` tags present, then we are finished:
|
||||
if len(info.byIndex) > 0 {
|
||||
return info
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// If there are no `ksql` tags in the struct, lets assume
|
||||
|
@ -261,7 +276,7 @@ func getTagNames(t reflect.Type) StructInfo {
|
|||
info.IsNestedStruct = true
|
||||
}
|
||||
|
||||
return info
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// DecodeAsSliceOfStructs makes several checks
|
||||
|
|
|
@ -85,6 +85,20 @@ func TestStructToMap(t *testing.T) {
|
|||
"age_attr": 42,
|
||||
}, m)
|
||||
})
|
||||
|
||||
t.Run("should return error for duplicated ksql tag names", func(t *testing.T) {
|
||||
_, err := StructToMap(struct {
|
||||
Name string `ksql:"name_attr"`
|
||||
DuplicatedName string `ksql:"name_attr"`
|
||||
Age int `ksql:"age_attr"`
|
||||
}{
|
||||
Name: "fake-name",
|
||||
Age: 42,
|
||||
DuplicatedName: "fake-duplicated-name",
|
||||
})
|
||||
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFillStructWith(t *testing.T) {
|
||||
|
|
|
@ -34,7 +34,11 @@ func FillStructWith(record interface{}, dbRow map[string]interface{}) error {
|
|||
)
|
||||
}
|
||||
|
||||
info := GetTagInfo(t)
|
||||
info, err := GetTagInfo(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for colName, rawSrc := range dbRow {
|
||||
fieldInfo := info.ByName(colName)
|
||||
if !fieldInfo.Valid {
|
||||
|
|
Loading…
Reference in New Issue