From 9e4583c3f8cc99b331b444bcbbf6642023f930d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= <vingarcia00@gmail.com>
Date: Sun, 23 May 2021 12:25:35 -0300
Subject: [PATCH] Add error check for preventing reflection panics in nested
 structs

---
 ksql.go      | 10 +++++++++-
 ksql_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/ksql.go b/ksql.go
index 52674db..5ada2ef 100644
--- a/ksql.go
+++ b/ksql.go
@@ -1005,7 +1005,15 @@ func buildSelectQueryForNestedStructs(
 	var fields []string
 	for i := 0; i < structType.NumField(); i++ {
 		nestedStructName := info.ByIndex(i).Name
-		nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type)
+		nestedStructType := structType.Field(i).Type
+		if nestedStructType.Kind() != reflect.Struct {
+			return "", fmt.Errorf(
+				"expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v",
+				nestedStructName, nestedStructType,
+			)
+		}
+
+		nestedStructInfo := structs.GetTagInfo(nestedStructType)
 		for j := 0; j < structType.Field(i).Type.NumField(); j++ {
 			fields = append(
 				fields,
diff --git a/ksql_test.go b/ksql_test.go
index be3aa30..94fa122 100644
--- a/ksql_test.go
+++ b/ksql_test.go
@@ -351,6 +351,52 @@ func TestQuery(t *testing.T) {
 					assert.NotEqual(t, nil, err)
 				})
 			})
+
+			t.Run("should report error for nested structs with invalid types", func(t *testing.T) {
+				t.Run("int", func(t *testing.T) {
+					db := connectDB(t, driver)
+					defer db.Close()
+
+					ctx := context.Background()
+					c := newTestDB(db, driver, "users")
+					var rows []struct {
+						Foo int `tablename:"foo"`
+					}
+					err := c.Query(ctx, &rows, fmt.Sprint(
+						`FROM users u JOIN posts p ON p.user_id = u.id`,
+						` WHERE u.name like `, c.dialect.Placeholder(0),
+						` ORDER BY u.id, p.id`,
+					), "% Ribeiro")
+
+					assert.NotEqual(t, nil, err)
+					msg := err.Error()
+					for _, str := range []string{"foo", "int"} {
+						assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
+					}
+				})
+
+				t.Run("*struct", func(t *testing.T) {
+					db := connectDB(t, driver)
+					defer db.Close()
+
+					ctx := context.Background()
+					c := newTestDB(db, driver, "users")
+					var rows []struct {
+						Foo *User `tablename:"foo"`
+					}
+					err := c.Query(ctx, &rows, fmt.Sprint(
+						`FROM users u JOIN posts p ON p.user_id = u.id`,
+						` WHERE u.name like `, c.dialect.Placeholder(0),
+						` ORDER BY u.id, p.id`,
+					), "% Ribeiro")
+
+					assert.NotEqual(t, nil, err)
+					msg := err.Error()
+					for _, str := range []string{"foo", "*ksql.User"} {
+						assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
+					}
+				})
+			})
 		})
 	}
 }