diff --git a/ksql.go b/ksql.go index 1c30a8b..52674db 100644 --- a/ksql.go +++ b/ksql.go @@ -813,11 +813,6 @@ func (nopScanner) Scan(value interface{}) error { } func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { - names, err := rows.Columns() - if err != nil { - return err - } - v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -833,6 +828,53 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { info := structs.GetTagInfo(t) + var scanArgs []interface{} + if info.IsNestedStruct { + // This version is positional meaning that it expect the arguments + // to follow an specific order. It's ok because we don't allow the + // user to type the "SELECT" part of the query for nested structs. + scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info) + } else { + names, err := rows.Columns() + if err != nil { + return err + } + // Since this version uses the names of the columns it works + // with any order of attributes/columns. + scanArgs = getScanArgsFromNames(dialect, names, v, info) + } + + return rows.Scan(scanArgs...) +} + +func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) []interface{} { + scanArgs := []interface{}{} + for i := 0; i < v.NumField(); i++ { + // TODO(vingarcia00): Handle case where type is pointer + nestedStructInfo := structs.GetTagInfo(t.Field(i).Type) + nestedStructValue := v.Field(i) + for j := 0; j < nestedStructValue.NumField(); j++ { + fieldInfo := nestedStructInfo.ByIndex(j) + + valueScanner := nopScannerValue + if fieldInfo.Valid { + valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() + if fieldInfo.SerializeAsJSON { + valueScanner = &jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: valueScanner, + } + } + } + + scanArgs = append(scanArgs, valueScanner) + } + } + + return scanArgs +} + +func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} { scanArgs := []interface{}{} for _, name := range names { fieldInfo := info.ByName(name) @@ -851,7 +893,7 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { scanArgs = append(scanArgs, valueScanner) } - return rows.Scan(scanArgs...) + return scanArgs } func buildSingleKeyDeleteQuery( @@ -923,18 +965,54 @@ func buildSelectQuery( dialect dialect, structType reflect.Type, selectQueryCache map[reflect.Type]string, -) (string, error) { +) (query string, err error) { if selectQuery, found := selectQueryCache[structType]; found { return selectQuery, nil } info := structs.GetTagInfo(structType) - var fields []string - for _, field := range info.Fields() { - fields = append(fields, dialect.Escape(field.Name)) + if info.IsNestedStruct { + query, err = buildSelectQueryForNestedStructs(dialect, structType, info) + if err != nil { + return "", err + } + } else { + query = buildSelectQueryForPlainStructs(dialect, structType, info) } - query := "SELECT " + strings.Join(fields, ", ") + " " selectQueryCache[structType] = query return query, nil } + +func buildSelectQueryForPlainStructs( + dialect dialect, + structType reflect.Type, + info structs.StructInfo, +) string { + var fields []string + for i := 0; i < structType.NumField(); i++ { + fields = append(fields, dialect.Escape(info.ByIndex(i).Name)) + } + + return "SELECT " + strings.Join(fields, ", ") + " " +} + +func buildSelectQueryForNestedStructs( + dialect dialect, + structType reflect.Type, + info structs.StructInfo, +) (string, error) { + var fields []string + for i := 0; i < structType.NumField(); i++ { + nestedStructName := info.ByIndex(i).Name + nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type) + for j := 0; j < structType.Field(i).Type.NumField(); j++ { + fields = append( + fields, + dialect.Escape(nestedStructName)+"."+dialect.Escape(nestedStructInfo.ByIndex(j).Name), + ) + } + } + + return "SELECT " + strings.Join(fields, ", ") + " ", nil +} diff --git a/ksql_test.go b/ksql_test.go index 19a6cad..b7dfe4e 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -34,6 +34,12 @@ type Address struct { Country string `json:"country"` } +type Post struct { + ID int `ksql:"id"` + UserID uint `ksql:"user_id"` + Title string `ksql:"title"` +} + func TestQuery(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { @@ -53,7 +59,7 @@ func TestQuery(t *testing.T) { for _, variation := range variations { t.Run(variation.desc, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -123,7 +129,7 @@ func TestQuery(t *testing.T) { }) t.Run("using slice of pointers to structs", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -190,12 +196,66 @@ func TestQuery(t *testing.T) { assert.Equal(t, "Bia Garcia", users[1].Name) assert.Equal(t, "BR", users[1].Address.Country) }) + + t.Run("should query joined tables correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joaoID uint + db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID) + + _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + var biaID uint + db.QueryRow(`SELECT id FROM users WHERE name = 'Bia Ribeiro'`).Scan(&biaID) + + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post1')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post2')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var rows []*struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + } + err = c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(rows)) + + assert.Equal(t, joaoID, rows[0].User.ID) + assert.Equal(t, "João Ribeiro", rows[0].User.Name) + assert.Equal(t, "João Post1", rows[0].Post.Title) + + assert.Equal(t, biaID, rows[1].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[1].User.Name) + assert.Equal(t, "Bia Post1", rows[1].Post.Title) + + assert.Equal(t, biaID, rows[2].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[2].User.Name) + assert.Equal(t, "Bia Post2", rows[2].Post.Title) + }) }) }) } t.Run("testing error cases", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -258,7 +318,7 @@ func TestQueryOne(t *testing.T) { }, } for _, variation := range variations { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -358,7 +418,7 @@ func TestInsert(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -428,7 +488,7 @@ func TestInsert(t *testing.T) { }) t.Run("testing error cases", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -485,7 +545,7 @@ func TestInsert(t *testing.T) { func TestDelete(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -628,7 +688,7 @@ func TestDelete(t *testing.T) { func TestUpdate(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -813,7 +873,7 @@ func TestQueryChunks(t *testing.T) { for _, variation := range variations { t.Run(variation.desc, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -853,7 +913,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should query one chunk correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -895,7 +955,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should query chunks of 1 correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -937,7 +997,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should load partially filled chunks correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -978,7 +1038,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1017,7 +1077,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1060,7 +1120,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1099,7 +1159,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1213,7 +1273,7 @@ func TestTransaction(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1240,7 +1300,7 @@ func TestTransaction(t *testing.T) { }) t.Run("should rollback when there are errors", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1281,7 +1341,7 @@ func TestTransaction(t *testing.T) { func TestScanRows(t *testing.T) { t.Run("should scan users correctly", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1310,7 +1370,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should ignore extra columns from query", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1342,7 +1402,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should report error for closed rows", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1363,7 +1423,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should report if record is not a pointer", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1382,7 +1442,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should report if record is not a pointer to struct", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1408,7 +1468,7 @@ var connectionString = map[string]string{ "sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql", } -func createTable(driver string) error { +func createTables(driver string) error { connStr := connectionString[driver] if connStr == "" { return fmt.Errorf("unsupported driver: '%s'", driver) @@ -1456,6 +1516,38 @@ func createTable(driver string) error { return fmt.Errorf("failed to create new users table: %s", err.Error()) } + db.Exec(`DROP TABLE posts`) + + switch driver { + case "sqlite3": + _, err = db.Exec(`CREATE TABLE posts ( + id INTEGER PRIMARY KEY, + user_id INTEGER, + title TEXT + )`) + case "postgres": + _, err = db.Exec(`CREATE TABLE posts ( + id serial PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + case "mysql": + _, err = db.Exec(`CREATE TABLE posts ( + id INT AUTO_INCREMENT PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + case "sqlserver": + _, err = db.Exec(`CREATE TABLE posts ( + id INT IDENTITY(1,1) PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + } + if err != nil { + return fmt.Errorf("failed to create new users table: %s", err.Error()) + } + return nil } diff --git a/structs/structs.go b/structs/structs.go index 63d7d57..3106d62 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -8,49 +8,56 @@ import ( "github.com/pkg/errors" ) -type structInfo struct { - byIndex map[int]*fieldInfo - byName map[string]*fieldInfo +// StructInfo stores metainformation of the struct +// parser in order to help the ksql library to work +// efectively and efficiently with reflection. +type StructInfo struct { + IsNestedStruct bool + byIndex map[int]*FieldInfo + byName map[string]*FieldInfo } -type fieldInfo struct { +// FieldInfo contains reflection and tags +// information regarding a specific field +// of a struct. +type FieldInfo struct { Name string Index int Valid bool SerializeAsJSON bool } -func (s structInfo) ByIndex(idx int) *fieldInfo { +// ByIndex returns either the *FieldInfo of a valid +// empty struct with Valid set to false +func (s StructInfo) ByIndex(idx int) *FieldInfo { field, found := s.byIndex[idx] if !found { - return &fieldInfo{} + return &FieldInfo{} } return field } -func (s structInfo) ByName(name string) *fieldInfo { +// ByName returns either the *FieldInfo of a valid +// empty struct with Valid set to false +func (s StructInfo) ByName(name string) *FieldInfo { field, found := s.byName[name] if !found { - return &fieldInfo{} + return &FieldInfo{} } return field } -func (s structInfo) Add(field fieldInfo) { +func (s StructInfo) add(field FieldInfo) { field.Valid = true s.byIndex[field.Index] = &field s.byName[field.Name] = &field } -func (s structInfo) Fields() map[int]*fieldInfo { - return s.byIndex -} - // This cache is kept as a pkg variable // because the total number of types on a program // should be finite. So keeping a single cache here // works fine. -var tagInfoCache = map[reflect.Type]structInfo{} +var tagInfoCache = map[reflect.Type]StructInfo{} // GetTagInfo efficiently returns the type information // using a global private cache @@ -58,16 +65,17 @@ var tagInfoCache = map[reflect.Type]structInfo{} // In the future we might move this cache inside // a struct, but for now this accessor is the one // we are using -func GetTagInfo(key reflect.Type) structInfo { +func GetTagInfo(key reflect.Type) StructInfo { return getCachedTagInfo(tagInfoCache, key) } -func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo { - info, found := tagInfoCache[key] - if !found { - info = getTagNames(key) - tagInfoCache[key] = info +func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo { + if info, found := tagInfoCache[key]; found { + return info } + + info := getTagNames(key) + tagInfoCache[key] = info return info } @@ -291,10 +299,10 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error // // This should save several calls to `Field(i).Tag.Get("foo")` // which improves performance by a lot. -func getTagNames(t reflect.Type) structInfo { - info := structInfo{ - byIndex: map[int]*fieldInfo{}, - byName: map[string]*fieldInfo{}, +func getTagNames(t reflect.Type) StructInfo { + info := StructInfo{ + byIndex: map[int]*FieldInfo{}, + byName: map[string]*FieldInfo{}, } for i := 0; i < t.NumField(); i++ { name := t.Field(i).Tag.Get("ksql") @@ -309,13 +317,36 @@ func getTagNames(t reflect.Type) structInfo { serializeAsJSON = tags[1] == "json" } - info.Add(fieldInfo{ + info.add(FieldInfo{ Name: name, Index: i, SerializeAsJSON: serializeAsJSON, }) } + // If there were `ksql` tags present, then we are finished: + if len(info.byIndex) > 0 { + return info + } + + // If there are no `ksql` tags in the struct, lets assume + // it is a struct tagged with `tablename` for allowing JOINs + for i := 0; i < t.NumField(); i++ { + name := t.Field(i).Tag.Get("tablename") + if name == "" { + continue + } + + info.add(FieldInfo{ + Name: name, + Index: i, + }) + } + + if len(info.byIndex) > 0 { + info.IsNestedStruct = true + } + return info }