From 9e4583c3f8cc99b331b444bcbbf6642023f930d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= <vingarcia00@gmail.com> Date: Sun, 23 May 2021 12:25:35 -0300 Subject: [PATCH] Add error check for preventing reflection panics in nested structs --- ksql.go | 10 +++++++++- ksql_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/ksql.go b/ksql.go index 52674db..5ada2ef 100644 --- a/ksql.go +++ b/ksql.go @@ -1005,7 +1005,15 @@ func buildSelectQueryForNestedStructs( var fields []string for i := 0; i < structType.NumField(); i++ { nestedStructName := info.ByIndex(i).Name - nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type) + nestedStructType := structType.Field(i).Type + if nestedStructType.Kind() != reflect.Struct { + return "", fmt.Errorf( + "expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v", + nestedStructName, nestedStructType, + ) + } + + nestedStructInfo := structs.GetTagInfo(nestedStructType) for j := 0; j < structType.Field(i).Type.NumField(); j++ { fields = append( fields, diff --git a/ksql_test.go b/ksql_test.go index be3aa30..94fa122 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -351,6 +351,52 @@ func TestQuery(t *testing.T) { assert.NotEqual(t, nil, err) }) }) + + t.Run("should report error for nested structs with invalid types", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var rows []struct { + Foo int `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "int"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + + t.Run("*struct", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var rows []struct { + Foo *User `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "*ksql.User"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + }) }) } }