Add error check for preventing reflection panics in nested structs

pull/2/head
Vinícius Garcia 2021-05-23 12:25:35 -03:00
parent 6bd61346d9
commit 9e4583c3f8
2 changed files with 55 additions and 1 deletions

10
ksql.go
View File

@ -1005,7 +1005,15 @@ func buildSelectQueryForNestedStructs(
var fields []string var fields []string
for i := 0; i < structType.NumField(); i++ { for i := 0; i < structType.NumField(); i++ {
nestedStructName := info.ByIndex(i).Name nestedStructName := info.ByIndex(i).Name
nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type) nestedStructType := structType.Field(i).Type
if nestedStructType.Kind() != reflect.Struct {
return "", fmt.Errorf(
"expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v",
nestedStructName, nestedStructType,
)
}
nestedStructInfo := structs.GetTagInfo(nestedStructType)
for j := 0; j < structType.Field(i).Type.NumField(); j++ { for j := 0; j < structType.Field(i).Type.NumField(); j++ {
fields = append( fields = append(
fields, fields,

View File

@ -351,6 +351,52 @@ func TestQuery(t *testing.T) {
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
}) })
}) })
t.Run("should report error for nested structs with invalid types", func(t *testing.T) {
t.Run("int", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
var rows []struct {
Foo int `tablename:"foo"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
assert.NotEqual(t, nil, err)
msg := err.Error()
for _, str := range []string{"foo", "int"} {
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
}
})
t.Run("*struct", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
var rows []struct {
Foo *User `tablename:"foo"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
assert.NotEqual(t, nil, err)
msg := err.Error()
for _, str := range []string{"foo", "*ksql.User"} {
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
}
})
})
}) })
} }
} }