mirror of https://github.com/VinGarcia/ksql.git
Add error check for preventing reflection panics in nested structs
parent
6bd61346d9
commit
9e4583c3f8
10
ksql.go
10
ksql.go
|
@ -1005,7 +1005,15 @@ func buildSelectQueryForNestedStructs(
|
||||||
var fields []string
|
var fields []string
|
||||||
for i := 0; i < structType.NumField(); i++ {
|
for i := 0; i < structType.NumField(); i++ {
|
||||||
nestedStructName := info.ByIndex(i).Name
|
nestedStructName := info.ByIndex(i).Name
|
||||||
nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type)
|
nestedStructType := structType.Field(i).Type
|
||||||
|
if nestedStructType.Kind() != reflect.Struct {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v",
|
||||||
|
nestedStructName, nestedStructType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
nestedStructInfo := structs.GetTagInfo(nestedStructType)
|
||||||
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
|
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
|
||||||
fields = append(
|
fields = append(
|
||||||
fields,
|
fields,
|
||||||
|
|
46
ksql_test.go
46
ksql_test.go
|
@ -351,6 +351,52 @@ func TestQuery(t *testing.T) {
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should report error for nested structs with invalid types", func(t *testing.T) {
|
||||||
|
t.Run("int", func(t *testing.T) {
|
||||||
|
db := connectDB(t, driver)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver, "users")
|
||||||
|
var rows []struct {
|
||||||
|
Foo int `tablename:"foo"`
|
||||||
|
}
|
||||||
|
err := c.Query(ctx, &rows, fmt.Sprint(
|
||||||
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
||||||
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
||||||
|
` ORDER BY u.id, p.id`,
|
||||||
|
), "% Ribeiro")
|
||||||
|
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
msg := err.Error()
|
||||||
|
for _, str := range []string{"foo", "int"} {
|
||||||
|
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("*struct", func(t *testing.T) {
|
||||||
|
db := connectDB(t, driver)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, driver, "users")
|
||||||
|
var rows []struct {
|
||||||
|
Foo *User `tablename:"foo"`
|
||||||
|
}
|
||||||
|
err := c.Query(ctx, &rows, fmt.Sprint(
|
||||||
|
`FROM users u JOIN posts p ON p.user_id = u.id`,
|
||||||
|
` WHERE u.name like `, c.dialect.Placeholder(0),
|
||||||
|
` ORDER BY u.id, p.id`,
|
||||||
|
), "% Ribeiro")
|
||||||
|
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
msg := err.Error()
|
||||||
|
for _, str := range []string{"foo", "*ksql.User"} {
|
||||||
|
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue