mirror of https://github.com/VinGarcia/ksql.git
Decouple ksql.DB from TagInfoCache so we can replace it during tests
parent
0e95506343
commit
74cb87bea0
|
@ -69,12 +69,24 @@ func (s StructInfo) NumFields() int {
|
|||
// works fine.
|
||||
var tagInfoCache = &sync.Map{}
|
||||
|
||||
// TagInfoCache implements the ksql.TagInfoCache interface
|
||||
// this abstraction was created for allowing the use of
|
||||
// mocks during tests.
|
||||
type TagInfoCache struct{}
|
||||
|
||||
// GetTagInfo efficiently returns the type information
|
||||
// using a global private cache
|
||||
//
|
||||
// In the future we might move this cache inside
|
||||
// a struct, but for now this accessor is the one
|
||||
// we are using
|
||||
func (t TagInfoCache) GetTagInfo(key reflect.Type) (StructInfo, error) {
|
||||
return getCachedTagInfo(tagInfoCache, key)
|
||||
}
|
||||
|
||||
// GetTagInfo is the static version of the method above
|
||||
// created for convenience when the extra abstraction
|
||||
// is unnecessary.
|
||||
func GetTagInfo(key reflect.Type) (StructInfo, error) {
|
||||
return getCachedTagInfo(tagInfoCache, key)
|
||||
}
|
||||
|
|
52
ksql.go
52
ksql.go
|
@ -33,6 +33,8 @@ type DB struct {
|
|||
driver string
|
||||
dialect Dialect
|
||||
db DBAdapter
|
||||
|
||||
tagInfoCache TagInfoCache
|
||||
}
|
||||
|
||||
// DBAdapter is minimalistic interface to decouple our implementation
|
||||
|
@ -45,6 +47,10 @@ type DBAdapter interface {
|
|||
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
|
||||
}
|
||||
|
||||
type TagInfoCache interface {
|
||||
GetTagInfo(key reflect.Type) (structs.StructInfo, error)
|
||||
}
|
||||
|
||||
// TxBeginner needs to be implemented by the DBAdapter in order to make it possible
|
||||
// to use the `ksql.Transaction()` function.
|
||||
type TxBeginner interface {
|
||||
|
@ -107,6 +113,8 @@ func NewWithAdapter(
|
|||
dialect: dialect,
|
||||
driver: dialectName,
|
||||
db: db,
|
||||
|
||||
tagInfoCache: structs.TagInfoCache{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -142,7 +150,7 @@ func (c DB) Query(
|
|||
slice = slice.Slice(0, 0)
|
||||
}
|
||||
|
||||
info, err := structs.GetTagInfo(structType)
|
||||
info, err := c.tagInfoCache.GetTagInfo(structType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -154,7 +162,7 @@ func (c DB) Query(
|
|||
}
|
||||
|
||||
if firstToken == "FROM" {
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -185,7 +193,7 @@ func (c DB) Query(
|
|||
elemPtr = elemPtr.Elem()
|
||||
}
|
||||
|
||||
err = scanRows(c.dialect, rows, elemPtr.Interface())
|
||||
err = c.scanRows(c.dialect, rows, elemPtr.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -232,7 +240,7 @@ func (c DB) QueryOne(
|
|||
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
||||
info, err := structs.GetTagInfo(tStruct)
|
||||
info, err := c.tagInfoCache.GetTagInfo(tStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -244,7 +252,7 @@ func (c DB) QueryOne(
|
|||
}
|
||||
|
||||
if firstToken == "FROM" {
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
|
||||
selectPrefix, err := c.buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -264,7 +272,7 @@ func (c DB) QueryOne(
|
|||
return ErrRecordNotFound
|
||||
}
|
||||
|
||||
err = scanRowsFromType(c.dialect, rows, record, t, v)
|
||||
err = c.scanRowsFromType(c.dialect, rows, record, t, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -305,7 +313,7 @@ func (c DB) QueryChunks(
|
|||
return err
|
||||
}
|
||||
|
||||
info, err := structs.GetTagInfo(structType)
|
||||
info, err := c.tagInfoCache.GetTagInfo(structType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -317,7 +325,7 @@ func (c DB) QueryChunks(
|
|||
}
|
||||
|
||||
if firstToken == "FROM" {
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -343,7 +351,7 @@ func (c DB) QueryChunks(
|
|||
chunk = reflect.Append(chunk, elemValue)
|
||||
}
|
||||
|
||||
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
||||
err = c.scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -415,7 +423,7 @@ func (c DB) Insert(
|
|||
return fmt.Errorf("can't insert in ksql.Table: %s", err)
|
||||
}
|
||||
|
||||
info, err := structs.GetTagInfo(t.Elem())
|
||||
info, err := c.tagInfoCache.GetTagInfo(t.Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -652,7 +660,7 @@ func (c DB) Patch(
|
|||
}
|
||||
tStruct = t.Elem()
|
||||
}
|
||||
info, err := structs.GetTagInfo(tStruct)
|
||||
info, err := c.tagInfoCache.GetTagInfo(tStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -929,13 +937,13 @@ func (nopScanner) Scan(value interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
|
||||
func (c DB) scanRows(dialect Dialect, rows Rows, record interface{}) error {
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
return scanRowsFromType(dialect, rows, record, t, v)
|
||||
return c.scanRowsFromType(dialect, rows, record, t, v)
|
||||
}
|
||||
|
||||
func scanRowsFromType(
|
||||
func (c DB) scanRowsFromType(
|
||||
dialect Dialect,
|
||||
rows Rows,
|
||||
record interface{},
|
||||
|
@ -953,7 +961,7 @@ func scanRowsFromType(
|
|||
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
||||
info, err := structs.GetTagInfo(t)
|
||||
info, err := c.tagInfoCache.GetTagInfo(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -963,7 +971,7 @@ func scanRowsFromType(
|
|||
// This version is positional meaning that it expect the arguments
|
||||
// to follow an specific order. It's ok because we don't allow the
|
||||
// user to type the "SELECT" part of the query for nested structs.
|
||||
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
scanArgs, err = c.getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -980,7 +988,7 @@ func scanRowsFromType(
|
|||
return rows.Scan(scanArgs...)
|
||||
}
|
||||
|
||||
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
|
||||
func (c DB) getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
|
||||
scanArgs := []interface{}{}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !info.ByIndex(i).Valid {
|
||||
|
@ -988,7 +996,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
|
|||
}
|
||||
|
||||
// TODO(vingarcia00): Handle case where type is pointer
|
||||
nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type)
|
||||
nestedStructInfo, err := c.tagInfoCache.GetTagInfo(t.Field(i).Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1076,7 +1084,7 @@ func getFirstToken(s string) string {
|
|||
return token.String()
|
||||
}
|
||||
|
||||
func buildSelectQuery(
|
||||
func (c DB) buildSelectQuery(
|
||||
dialect Dialect,
|
||||
structType reflect.Type,
|
||||
info structs.StructInfo,
|
||||
|
@ -1091,7 +1099,7 @@ func buildSelectQuery(
|
|||
}
|
||||
|
||||
if info.IsNestedStruct {
|
||||
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
|
||||
query, err = c.buildSelectQueryForNestedStructs(dialect, structType, info)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -1121,7 +1129,7 @@ func buildSelectQueryForPlainStructs(
|
|||
return "SELECT " + strings.Join(fields, ", ") + " "
|
||||
}
|
||||
|
||||
func buildSelectQueryForNestedStructs(
|
||||
func (c DB) buildSelectQueryForNestedStructs(
|
||||
dialect Dialect,
|
||||
structType reflect.Type,
|
||||
info structs.StructInfo,
|
||||
|
@ -1142,7 +1150,7 @@ func buildSelectQueryForNestedStructs(
|
|||
)
|
||||
}
|
||||
|
||||
nestedStructTagInfo, err := structs.GetTagInfo(nestedStructType)
|
||||
nestedStructTagInfo, err := c.tagInfoCache.GetTagInfo(nestedStructType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vingarcia/ksql/internal/structs"
|
||||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
"github.com/vingarcia/ksql/nullable"
|
||||
)
|
||||
|
@ -2455,7 +2456,7 @@ func ScanRowsTest(
|
|||
tt.AssertEqual(t, rows.Next(), true)
|
||||
|
||||
var u user
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, u.Name, "User2")
|
||||
|
@ -2488,7 +2489,7 @@ func ScanRowsTest(
|
|||
// Omitted for testing purposes:
|
||||
// Name string `ksql:"name"`
|
||||
}
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, u.Age, 22)
|
||||
|
@ -2504,6 +2505,7 @@ func ScanRowsTest(
|
|||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
||||
tt.AssertNoErr(t, err)
|
||||
|
@ -2511,7 +2513,7 @@ func ScanRowsTest(
|
|||
var u user
|
||||
err = rows.Close()
|
||||
tt.AssertNoErr(t, err)
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
tt.AssertNotEqual(t, err, nil)
|
||||
})
|
||||
|
||||
|
@ -2525,13 +2527,14 @@ func ScanRowsTest(
|
|||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
||||
tt.AssertNoErr(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var u user
|
||||
err = scanRows(dialect, rows, u)
|
||||
err = c.scanRows(dialect, rows, u)
|
||||
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
|
||||
})
|
||||
|
||||
|
@ -2545,13 +2548,14 @@ func ScanRowsTest(
|
|||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
||||
tt.AssertNoErr(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var u map[string]interface{}
|
||||
err = scanRows(dialect, rows, &u)
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "map[string]interface")
|
||||
})
|
||||
})
|
||||
|
@ -2684,6 +2688,8 @@ func newTestDB(db DBAdapter, driver string) DB {
|
|||
driver: driver,
|
||||
dialect: supportedDialects[driver],
|
||||
db: db,
|
||||
|
||||
tagInfoCache: structs.TagInfoCache{},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue