diff --git a/kbuilder/insert.go b/kbuilder/insert.go index d65c483..7b54dae 100644 --- a/kbuilder/insert.go +++ b/kbuilder/insert.go @@ -73,7 +73,10 @@ func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []inte return "", nil, fmt.Errorf("expected Data attr to be a struct or slice of structs but got: %v", t) } - info := kstructs.GetTagInfo(t) + info, err := kstructs.GetTagInfo(t) + if err != nil { + return "", nil, err + } b.WriteString(" (") var escapedNames []string diff --git a/kbuilder/query.go b/kbuilder/query.go index 2fc4d28..f222cea 100644 --- a/kbuilder/query.go +++ b/kbuilder/query.go @@ -204,7 +204,10 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { return query, nil } - info := kstructs.GetTagInfo(t) + info, err := kstructs.GetTagInfo(t) + if err != nil { + return "", err + } var escapedNames []string for i := 0; i < info.NumFields(); i++ { diff --git a/ksql.go b/ksql.go index 9485725..6658107 100644 --- a/ksql.go +++ b/ksql.go @@ -185,7 +185,10 @@ func (c DB) Query( slice = slice.Slice(0, 0) } - info := kstructs.GetTagInfo(structType) + info, err := kstructs.GetTagInfo(structType) + if err != nil { + return err + } firstToken := strings.ToUpper(getFirstToken(query)) if info.IsNestedStruct && firstToken == "SELECT" { @@ -272,7 +275,10 @@ func (c DB) QueryOne( return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } - info := kstructs.GetTagInfo(tStruct) + info, err := kstructs.GetTagInfo(tStruct) + if err != nil { + return err + } firstToken := strings.ToUpper(getFirstToken(query)) if info.IsNestedStruct && firstToken == "SELECT" { @@ -342,7 +348,10 @@ func (c DB) QueryChunks( return err } - info := kstructs.GetTagInfo(structType) + info, err := kstructs.GetTagInfo(structType) + if err != nil { + return err + } firstToken := strings.ToUpper(getFirstToken(parser.Query)) if info.IsNestedStruct && firstToken == "SELECT" { @@ -445,7 +454,10 @@ func (c DB) Insert( return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record) } - info := kstructs.GetTagInfo(t.Elem()) + info, err := kstructs.GetTagInfo(t.Elem()) + if err != nil { + return err + } query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, t, v, info, record, table.idColumns...) if err != nil { @@ -676,7 +688,10 @@ func (c DB) Update( } tStruct = t.Elem() } - info := kstructs.GetTagInfo(tStruct) + info, err := kstructs.GetTagInfo(tStruct) + if err != nil { + return err + } query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...) if err != nil { @@ -940,14 +955,20 @@ func scanRowsFromType( return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record) } - info := kstructs.GetTagInfo(t) + info, err := kstructs.GetTagInfo(t) + if err != nil { + return err + } var scanArgs []interface{} if info.IsNestedStruct { // This version is positional meaning that it expect the arguments // to follow an specific order. It's ok because we don't allow the // user to type the "SELECT" part of the query for nested kstructs. - scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info) + scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info) + if err != nil { + return err + } } else { names, err := rows.Columns() if err != nil { @@ -961,11 +982,15 @@ func scanRowsFromType( return rows.Scan(scanArgs...) } -func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { +func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) ([]interface{}, error) { scanArgs := []interface{}{} for i := 0; i < v.NumField(); i++ { // TODO(vingarcia00): Handle case where type is pointer - nestedStructInfo := kstructs.GetTagInfo(t.Field(i).Type) + nestedStructInfo, err := kstructs.GetTagInfo(t.Field(i).Type) + if err != nil { + return nil, err + } + nestedStructValue := v.Field(i) for j := 0; j < nestedStructValue.NumField(); j++ { fieldInfo := nestedStructInfo.ByIndex(j) @@ -985,7 +1010,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r } } - return scanArgs + return scanArgs, nil } func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info kstructs.StructInfo) []interface{} { @@ -1127,7 +1152,11 @@ func buildSelectQueryForNestedStructs( ) } - nestedStructInfo := kstructs.GetTagInfo(nestedStructType) + nestedStructInfo, err := kstructs.GetTagInfo(nestedStructType) + if err != nil { + return "", err + } + for j := 0; j < structType.Field(i).Type.NumField(); j++ { fields = append( fields, diff --git a/kstructs/structs.go b/kstructs/structs.go index 968de3a..ea87d4e 100644 --- a/kstructs/structs.go +++ b/kstructs/structs.go @@ -68,18 +68,22 @@ var tagInfoCache = map[reflect.Type]StructInfo{} // In the future we might move this cache inside // a struct, but for now this accessor is the one // we are using -func GetTagInfo(key reflect.Type) StructInfo { +func GetTagInfo(key reflect.Type) (StructInfo, error) { return getCachedTagInfo(tagInfoCache, key) } -func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo { +func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) (StructInfo, error) { if info, found := tagInfoCache[key]; found { - return info + return info, nil + } + + info, err := getTagNames(key) + if err != nil { + return StructInfo{}, err } - info := getTagNames(key) tagInfoCache[key] = info - return info + return info, nil } // StructToMap converts any struct type to a map based on @@ -103,11 +107,15 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) { return nil, fmt.Errorf("input must be a struct or struct pointer") } - info := getCachedTagInfo(tagInfoCache, t) + info, err := getCachedTagInfo(tagInfoCache, t) + if err != nil { + return nil, err + } m := map[string]interface{}{} for i := 0; i < v.NumField(); i++ { - if !info.ByIndex(i).Valid { + fieldInfo := info.ByIndex(i) + if !fieldInfo.Valid { continue } @@ -121,7 +129,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) { field = field.Elem() } - m[info.ByIndex(i).Name] = field.Interface() + m[fieldInfo.Name] = field.Interface() } return m, nil @@ -213,7 +221,7 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) { // // This should save several calls to `Field(i).Tag.Get("foo")` // which improves performance by a lot. -func getTagNames(t reflect.Type) StructInfo { +func getTagNames(t reflect.Type) (StructInfo, error) { info := StructInfo{ byIndex: map[int]*FieldInfo{}, byName: map[string]*FieldInfo{}, @@ -231,6 +239,13 @@ func getTagNames(t reflect.Type) StructInfo { serializeAsJSON = tags[1] == "json" } + if _, found := info.byName[name]; found { + return StructInfo{}, fmt.Errorf( + "struct contains multiple attributes with the same ksql tag name: '%s'", + name, + ) + } + info.add(FieldInfo{ Name: name, Index: i, @@ -240,7 +255,7 @@ func getTagNames(t reflect.Type) StructInfo { // If there were `ksql` tags present, then we are finished: if len(info.byIndex) > 0 { - return info + return info, nil } // If there are no `ksql` tags in the struct, lets assume @@ -261,7 +276,7 @@ func getTagNames(t reflect.Type) StructInfo { info.IsNestedStruct = true } - return info + return info, nil } // DecodeAsSliceOfStructs makes several checks diff --git a/kstructs/structs_test.go b/kstructs/structs_test.go index b9cb98f..388b6b9 100644 --- a/kstructs/structs_test.go +++ b/kstructs/structs_test.go @@ -85,6 +85,20 @@ func TestStructToMap(t *testing.T) { "age_attr": 42, }, m) }) + + t.Run("should return error for duplicated ksql tag names", func(t *testing.T) { + _, err := StructToMap(struct { + Name string `ksql:"name_attr"` + DuplicatedName string `ksql:"name_attr"` + Age int `ksql:"age_attr"` + }{ + Name: "fake-name", + Age: 42, + DuplicatedName: "fake-duplicated-name", + }) + + assert.NotEqual(t, nil, err) + }) } func TestFillStructWith(t *testing.T) { diff --git a/kstructs/testhelpers.go b/kstructs/testhelpers.go index caa7bb7..6955a9e 100644 --- a/kstructs/testhelpers.go +++ b/kstructs/testhelpers.go @@ -34,7 +34,11 @@ func FillStructWith(record interface{}, dbRow map[string]interface{}) error { ) } - info := GetTagInfo(t) + info, err := GetTagInfo(t) + if err != nil { + return err + } + for colName, rawSrc := range dbRow { fieldInfo := info.ByName(colName) if !fieldInfo.Valid {