Add error check for structs with duplicated tag names

Closes #6
pull/10/head
Vinícius Garcia 2021-12-22 19:36:37 -03:00
parent e970a3546a
commit 82a43fda87
6 changed files with 93 additions and 25 deletions

View File

@ -73,7 +73,10 @@ func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []inte
return "", nil, fmt.Errorf("expected Data attr to be a struct or slice of structs but got: %v", t)
}
info := kstructs.GetTagInfo(t)
info, err := kstructs.GetTagInfo(t)
if err != nil {
return "", nil, err
}
b.WriteString(" (")
var escapedNames []string

View File

@ -204,7 +204,10 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
return query, nil
}
info := kstructs.GetTagInfo(t)
info, err := kstructs.GetTagInfo(t)
if err != nil {
return "", err
}
var escapedNames []string
for i := 0; i < info.NumFields(); i++ {

51
ksql.go
View File

@ -185,7 +185,10 @@ func (c DB) Query(
slice = slice.Slice(0, 0)
}
info := kstructs.GetTagInfo(structType)
info, err := kstructs.GetTagInfo(structType)
if err != nil {
return err
}
firstToken := strings.ToUpper(getFirstToken(query))
if info.IsNestedStruct && firstToken == "SELECT" {
@ -272,7 +275,10 @@ func (c DB) QueryOne(
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
}
info := kstructs.GetTagInfo(tStruct)
info, err := kstructs.GetTagInfo(tStruct)
if err != nil {
return err
}
firstToken := strings.ToUpper(getFirstToken(query))
if info.IsNestedStruct && firstToken == "SELECT" {
@ -342,7 +348,10 @@ func (c DB) QueryChunks(
return err
}
info := kstructs.GetTagInfo(structType)
info, err := kstructs.GetTagInfo(structType)
if err != nil {
return err
}
firstToken := strings.ToUpper(getFirstToken(parser.Query))
if info.IsNestedStruct && firstToken == "SELECT" {
@ -445,7 +454,10 @@ func (c DB) Insert(
return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
}
info := kstructs.GetTagInfo(t.Elem())
info, err := kstructs.GetTagInfo(t.Elem())
if err != nil {
return err
}
query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, t, v, info, record, table.idColumns...)
if err != nil {
@ -676,7 +688,10 @@ func (c DB) Update(
}
tStruct = t.Elem()
}
info := kstructs.GetTagInfo(tStruct)
info, err := kstructs.GetTagInfo(tStruct)
if err != nil {
return err
}
query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...)
if err != nil {
@ -940,14 +955,20 @@ func scanRowsFromType(
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
}
info := kstructs.GetTagInfo(t)
info, err := kstructs.GetTagInfo(t)
if err != nil {
return err
}
var scanArgs []interface{}
if info.IsNestedStruct {
// This version is positional meaning that it expect the arguments
// to follow an specific order. It's ok because we don't allow the
// user to type the "SELECT" part of the query for nested kstructs.
scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info)
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
if err != nil {
return err
}
} else {
names, err := rows.Columns()
if err != nil {
@ -961,11 +982,15 @@ func scanRowsFromType(
return rows.Scan(scanArgs...)
}
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} {
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) ([]interface{}, error) {
scanArgs := []interface{}{}
for i := 0; i < v.NumField(); i++ {
// TODO(vingarcia00): Handle case where type is pointer
nestedStructInfo := kstructs.GetTagInfo(t.Field(i).Type)
nestedStructInfo, err := kstructs.GetTagInfo(t.Field(i).Type)
if err != nil {
return nil, err
}
nestedStructValue := v.Field(i)
for j := 0; j < nestedStructValue.NumField(); j++ {
fieldInfo := nestedStructInfo.ByIndex(j)
@ -985,7 +1010,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
}
}
return scanArgs
return scanArgs, nil
}
func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info kstructs.StructInfo) []interface{} {
@ -1127,7 +1152,11 @@ func buildSelectQueryForNestedStructs(
)
}
nestedStructInfo := kstructs.GetTagInfo(nestedStructType)
nestedStructInfo, err := kstructs.GetTagInfo(nestedStructType)
if err != nil {
return "", err
}
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
fields = append(
fields,

View File

@ -68,18 +68,22 @@ var tagInfoCache = map[reflect.Type]StructInfo{}
// In the future we might move this cache inside
// a struct, but for now this accessor is the one
// we are using
func GetTagInfo(key reflect.Type) StructInfo {
func GetTagInfo(key reflect.Type) (StructInfo, error) {
return getCachedTagInfo(tagInfoCache, key)
}
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo {
func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) (StructInfo, error) {
if info, found := tagInfoCache[key]; found {
return info
return info, nil
}
info, err := getTagNames(key)
if err != nil {
return StructInfo{}, err
}
info := getTagNames(key)
tagInfoCache[key] = info
return info
return info, nil
}
// StructToMap converts any struct type to a map based on
@ -103,11 +107,15 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
return nil, fmt.Errorf("input must be a struct or struct pointer")
}
info := getCachedTagInfo(tagInfoCache, t)
info, err := getCachedTagInfo(tagInfoCache, t)
if err != nil {
return nil, err
}
m := map[string]interface{}{}
for i := 0; i < v.NumField(); i++ {
if !info.ByIndex(i).Valid {
fieldInfo := info.ByIndex(i)
if !fieldInfo.Valid {
continue
}
@ -121,7 +129,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
field = field.Elem()
}
m[info.ByIndex(i).Name] = field.Interface()
m[fieldInfo.Name] = field.Interface()
}
return m, nil
@ -213,7 +221,7 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) {
//
// This should save several calls to `Field(i).Tag.Get("foo")`
// which improves performance by a lot.
func getTagNames(t reflect.Type) StructInfo {
func getTagNames(t reflect.Type) (StructInfo, error) {
info := StructInfo{
byIndex: map[int]*FieldInfo{},
byName: map[string]*FieldInfo{},
@ -231,6 +239,13 @@ func getTagNames(t reflect.Type) StructInfo {
serializeAsJSON = tags[1] == "json"
}
if _, found := info.byName[name]; found {
return StructInfo{}, fmt.Errorf(
"struct contains multiple attributes with the same ksql tag name: '%s'",
name,
)
}
info.add(FieldInfo{
Name: name,
Index: i,
@ -240,7 +255,7 @@ func getTagNames(t reflect.Type) StructInfo {
// If there were `ksql` tags present, then we are finished:
if len(info.byIndex) > 0 {
return info
return info, nil
}
// If there are no `ksql` tags in the struct, lets assume
@ -261,7 +276,7 @@ func getTagNames(t reflect.Type) StructInfo {
info.IsNestedStruct = true
}
return info
return info, nil
}
// DecodeAsSliceOfStructs makes several checks

View File

@ -85,6 +85,20 @@ func TestStructToMap(t *testing.T) {
"age_attr": 42,
}, m)
})
t.Run("should return error for duplicated ksql tag names", func(t *testing.T) {
_, err := StructToMap(struct {
Name string `ksql:"name_attr"`
DuplicatedName string `ksql:"name_attr"`
Age int `ksql:"age_attr"`
}{
Name: "fake-name",
Age: 42,
DuplicatedName: "fake-duplicated-name",
})
assert.NotEqual(t, nil, err)
})
}
func TestFillStructWith(t *testing.T) {

View File

@ -34,7 +34,11 @@ func FillStructWith(record interface{}, dbRow map[string]interface{}) error {
)
}
info := GetTagInfo(t)
info, err := GetTagInfo(t)
if err != nil {
return err
}
for colName, rawSrc := range dbRow {
fieldInfo := info.ByName(colName)
if !fieldInfo.Valid {