Extract struct helper functions into the structs package

pull/2/head
Vinícius Garcia 2021-02-16 00:01:19 -03:00
parent 203b141aca
commit 304e5bde49
4 changed files with 47 additions and 31 deletions

View File

@ -9,6 +9,7 @@ import (
"github.com/tj/assert" "github.com/tj/assert"
"github.com/vingarcia/kissorm" "github.com/vingarcia/kissorm"
"github.com/vingarcia/kissorm/nullable" "github.com/vingarcia/kissorm/nullable"
"github.com/vingarcia/kissorm/structs"
) )
func TestCreateUser(t *testing.T) { func TestCreateUser(t *testing.T) {
@ -58,7 +59,7 @@ func TestCreateUser(t *testing.T) {
// //
// If you are inserting an anonymous struct (not usual) this function // If you are inserting an anonymous struct (not usual) this function
// can make your tests shorter: // can make your tests shorter:
uMap, err := kissorm.StructToMap(record) uMap, err := structs.StructToMap(record)
if err != nil { if err != nil {
return err return err
} }
@ -95,7 +96,7 @@ func TestUpdateUserScore(t *testing.T) {
DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error {
// This function will use reflection to fill the // This function will use reflection to fill the
// struct fields with the values from the map // struct fields with the values from the map
return kissorm.FillStructWith(result, map[string]interface{}{ return structs.FillStructWith(result, map[string]interface{}{
// Use int this map the keys you set on the kissorm tags, e.g. `kissorm:"score"` // Use int this map the keys you set on the kissorm tags, e.g. `kissorm:"score"`
// Each of these fields represent the database rows returned // Each of these fields represent the database rows returned
// by the query. // by the query.
@ -138,7 +139,7 @@ func TestListUsers(t *testing.T) {
DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error {
// This function will use reflection to fill the // This function will use reflection to fill the
// struct fields with the values from the map // struct fields with the values from the map
return kissorm.FillStructWith(result, map[string]interface{}{ return structs.FillStructWith(result, map[string]interface{}{
// Use int this map the keys you set on the kissorm tags, e.g. `kissorm:"score"` // Use int this map the keys you set on the kissorm tags, e.g. `kissorm:"score"`
// Each of these fields represent the database rows returned // Each of these fields represent the database rows returned
// by the query. // by the query.
@ -147,7 +148,7 @@ func TestListUsers(t *testing.T) {
}), }),
usersTableMock.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). usersTableMock.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, results interface{}, query string, params ...interface{}) error { DoAndReturn(func(ctx context.Context, results interface{}, query string, params ...interface{}) error {
return kissorm.FillSliceWith(results, []map[string]interface{}{ return structs.FillSliceWith(results, []map[string]interface{}{
{ {
"id": 1, "id": 1,
"name": "fake name", "name": "fake name",

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/vingarcia/kissorm/structs"
) )
// DB represents the kissorm client responsible for // DB represents the kissorm client responsible for
@ -110,7 +111,7 @@ func (c DB) Query(
} }
sliceType := slicePtrType.Elem() sliceType := slicePtrType.Elem()
slice := slicePtr.Elem() slice := slicePtr.Elem()
structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(sliceType) structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(sliceType)
if err != nil { if err != nil {
return err return err
} }
@ -230,7 +231,7 @@ func (c DB) QueryChunks(
chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize) chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize)
structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(chunkType) structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(chunkType)
if err != nil { if err != nil {
return err return err
} }
@ -368,7 +369,7 @@ func (c DB) insertWithReturningID(
if err = assertStructPtr(t); err != nil { if err = assertStructPtr(t); err != nil {
return errors.Wrap(err, "can't write id field") return errors.Wrap(err, "can't write id field")
} }
info := getCachedTagInfo(tagInfoCache, t.Elem()) info := structs.GetTagInfo(t.Elem())
var scanFields []interface{} var scanFields []interface{}
for _, id := range idNames { for _, id := range idNames {
@ -403,7 +404,7 @@ func (c DB) insertWithLastInsertID(
return errors.Wrap(err, "can't write to `"+idName+"` field") return errors.Wrap(err, "can't write to `"+idName+"` field")
} }
info := getCachedTagInfo(tagInfoCache, t.Elem()) info := structs.GetTagInfo(t.Elem())
id, err := result.LastInsertId() id, err := result.LastInsertId()
if err != nil { if err != nil {
@ -485,7 +486,7 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter
t := reflect.TypeOf(ids[i]) t := reflect.TypeOf(ids[i])
switch t.Kind() { switch t.Kind() {
case reflect.Struct: case reflect.Struct:
m, err := StructToMap(ids[i]) m, err := structs.StructToMap(ids[i])
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i) return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i)
} }
@ -542,7 +543,7 @@ func buildInsertQuery(
record interface{}, record interface{},
idFieldNames ...string, idFieldNames ...string,
) (query string, params []interface{}, err error) { ) (query string, params []interface{}, err error) {
recordMap, err := StructToMap(record) recordMap, err := structs.StructToMap(record)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -588,7 +589,7 @@ func buildUpdateQuery(
record interface{}, record interface{},
idFieldNames ...string, idFieldNames ...string,
) (query string, args []interface{}, err error) { ) (query string, args []interface{}, err error) {
recordMap, err := StructToMap(record) recordMap, err := structs.StructToMap(record)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -683,12 +684,6 @@ func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) error {
} }
} }
// This cache is kept as a pkg variable
// because the total number of types on a program
// should be finite. So keeping a single cache here
// works fine.
var tagInfoCache = map[reflect.Type]structInfo{}
var errType = reflect.TypeOf(new(error)).Elem() var errType = reflect.TypeOf(new(error)).Elem()
func parseInputFunc(fn interface{}) (reflect.Type, error) { func parseInputFunc(fn interface{}) (reflect.Type, error) {
@ -744,7 +739,7 @@ func scanRows(rows *sql.Rows, record interface{}) error {
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", record) return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", record)
} }
info := getCachedTagInfo(tagInfoCache, t) info := structs.GetTagInfo(t)
scanArgs := []interface{}{} scanArgs := []interface{}{}
for _, name := range names { for _, name := range names {
@ -760,15 +755,6 @@ func scanRows(rows *sql.Rows, record interface{}) error {
return rows.Scan(scanArgs...) return rows.Scan(scanArgs...)
} }
func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo {
info, found := tagInfoCache[key]
if !found {
info = getTagNames(key)
tagInfoCache[key] = info
}
return info
}
func buildSingleKeyDeleteQuery( func buildSingleKeyDeleteQuery(
dialect dialect, dialect dialect,
table string, table string,

View File

@ -1,4 +1,4 @@
package kissorm package structs
import ( import (
"fmt" "fmt"
@ -12,6 +12,31 @@ type structInfo struct {
Index map[string]int Index map[string]int
} }
// This cache is kept as a pkg variable
// because the total number of types on a program
// should be finite. So keeping a single cache here
// works fine.
var tagInfoCache = map[reflect.Type]structInfo{}
// GetTagInfo efficiently returns the type information
// using a global private cache
//
// In the future we might move this cache inside
// a struct, but for now this accessor is the one
// we are using
func GetTagInfo(key reflect.Type) structInfo {
return getCachedTagInfo(tagInfoCache, key)
}
func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo {
info, found := tagInfoCache[key]
if !found {
info = getTagNames(key)
tagInfoCache[key] = info
}
return info
}
// StructToMap converts any struct type to a map based on // StructToMap converts any struct type to a map based on
// the tag named `kissorm`, i.e. `kissorm:"map_key_name"` // the tag named `kissorm`, i.e. `kissorm:"map_key_name"`
// //
@ -200,7 +225,7 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
) )
} }
structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(sliceType.Elem()) structType, isSliceOfPtrs, err := DecodeAsSliceOfStructs(sliceType.Elem())
if err != nil { if err != nil {
return errors.Wrap(err, "FillSliceWith") return errors.Wrap(err, "FillSliceWith")
} }
@ -249,7 +274,11 @@ func getTagNames(t reflect.Type) structInfo {
return info return info
} }
func decodeAsSliceOfStructs(slice reflect.Type) ( // DecodeAsSliceOfStructs makes several checks
// while decoding an input type and returns
// useful information so that it is easier
// to manipulate the original slice later.
func DecodeAsSliceOfStructs(slice reflect.Type) (
structType reflect.Type, structType reflect.Type,
isSliceOfPtrs bool, isSliceOfPtrs bool,
err error, err error,

View File

@ -1,4 +1,4 @@
package kissorm package structs
import ( import (
"testing" "testing"