Add private scanRows() function

pull/2/head
Vinícius Garcia 2020-11-21 02:35:56 -03:00
parent c7e743527f
commit b6f69d52aa
2 changed files with 73 additions and 16 deletions

View File

@ -2,6 +2,7 @@ package kissorm
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
@ -338,11 +339,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
return nil, fmt.Errorf("input must be a struct or struct pointer")
}
info, found := tagInfoCache[t]
if !found {
info = getTagNames(t)
tagInfoCache[t] = info
}
info := getTagInfoWithCache(tagInfoCache, t)
m := map[string]interface{}{}
for i := 0; i < v.NumField(); i++ {
@ -411,11 +408,7 @@ func FillStructWith(record interface{}, dbRow map[string]interface{}) error {
)
}
info, found := tagInfoCache[t]
if !found {
info = getTagNames(t)
tagInfoCache[t] = info
}
info := getTagInfoWithCache(tagInfoCache, t)
for colName, attr := range dbRow {
attrValue := reflect.ValueOf(attr)
@ -457,12 +450,6 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
return fmt.Errorf("FillSliceWith: %s", err.Error())
}
info, found := tagInfoCache[structType]
if !found {
info = getTagNames(structType)
tagInfoCache[structType] = info
}
slice := sliceRef.Elem()
for idx, row := range dbRows {
if slice.Len() <= idx {
@ -543,3 +530,41 @@ func parseInputFunc(fn interface{}) (reflect.Type, error) {
return argsType, nil
}
func scanRows(rows *sql.Rows, record interface{}) error {
names, err := rows.Columns() // rows.QueryContext(ctx, query string, args ...interface{}) (*Rows, error)
if err != nil {
return err
}
v := reflect.ValueOf(record)
t := v.Type()
if t.Kind() != reflect.Ptr {
return fmt.Errorf("kissorm: expected to receive a pointer to struct, but got: %T", record)
}
v = v.Elem()
t = t.Elem()
if t.Kind() != reflect.Struct {
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", record)
}
info := getTagInfoWithCache(tagInfoCache, t)
scanArgs := []interface{}{}
for _, name := range names {
scanArgs = append(scanArgs, v.Field(info.Index[name]).Addr().Interface())
}
return rows.Scan(scanArgs...)
}
func getTagInfoWithCache(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo {
info, found := tagInfoCache[key]
if !found {
info = getTagNames(key)
tagInfoCache[key] = info
}
return info
}

View File

@ -823,6 +823,38 @@ func TestFillSliceWith(t *testing.T) {
})
}
func TestScanRows(t *testing.T) {
t.Run("should scan users correctly", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
ctx := context.TODO()
db := connectDB(t)
defer db.Close()
c := Client{
db: db,
tableName: "users",
}
_ = c.Insert(ctx, &User{Name: "User1", Age: 22})
_ = c.Insert(ctx, &User{Name: "User2", Age: 14})
_ = c.Insert(ctx, &User{Name: "User3", Age: 43})
rows, err := db.DB().QueryContext(ctx, "select * from users where name='User2'")
assert.Equal(t, nil, err)
assert.Equal(t, true, rows.Next())
var u User
err = scanRows(rows, &u)
assert.Equal(t, nil, err)
assert.Equal(t, "User2", u.Name)
assert.Equal(t, 14, u.Age)
})
}
func createTable() error {
db, err := gorm.Open("sqlite3", "/tmp/test.db")
if err != nil {