Decouple ksql.DB from TagInfoCache so we can replace it during tests

pull/29/head
Vinícius Garcia 2022-08-24 22:45:24 -03:00
parent 0e95506343
commit 74cb87bea0
3 changed files with 53 additions and 27 deletions

View File

@ -69,12 +69,24 @@ func (s StructInfo) NumFields() int {
// works fine.
var tagInfoCache = &sync.Map{}
// TagInfoCache implements the ksql.TagInfoCache interface
// this abstraction was created for allowing the use of
// mocks during tests.
type TagInfoCache struct{}
// GetTagInfo efficiently returns the type information
// using a global private cache
//
// In the future we might move this cache inside
// a struct, but for now this accessor is the one
// we are using
func (t TagInfoCache) GetTagInfo(key reflect.Type) (StructInfo, error) {
return getCachedTagInfo(tagInfoCache, key)
}
// GetTagInfo is the static version of the method above
// created for convenience when the extra abstraction
// is unnecessary.
func GetTagInfo(key reflect.Type) (StructInfo, error) {
return getCachedTagInfo(tagInfoCache, key)
}

52
ksql.go
View File

@ -33,6 +33,8 @@ type DB struct {
driver string
dialect Dialect
db DBAdapter
tagInfoCache TagInfoCache
}
// DBAdapter is minimalistic interface to decouple our implementation
@ -45,6 +47,10 @@ type DBAdapter interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
}
type TagInfoCache interface {
GetTagInfo(key reflect.Type) (structs.StructInfo, error)
}
// TxBeginner needs to be implemented by the DBAdapter in order to make it possible
// to use the `ksql.Transaction()` function.
type TxBeginner interface {
@ -107,6 +113,8 @@ func NewWithAdapter(
dialect: dialect,
driver: dialectName,
db: db,
tagInfoCache: structs.TagInfoCache{},
}, nil
}
@ -142,7 +150,7 @@ func (c DB) Query(
slice = slice.Slice(0, 0)
}
info, err := structs.GetTagInfo(structType)
info, err := c.tagInfoCache.GetTagInfo(structType)
if err != nil {
return err
}
@ -154,7 +162,7 @@ func (c DB) Query(
}
if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
if err != nil {
return err
}
@ -185,7 +193,7 @@ func (c DB) Query(
elemPtr = elemPtr.Elem()
}
err = scanRows(c.dialect, rows, elemPtr.Interface())
err = c.scanRows(c.dialect, rows, elemPtr.Interface())
if err != nil {
return err
}
@ -232,7 +240,7 @@ func (c DB) QueryOne(
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
}
info, err := structs.GetTagInfo(tStruct)
info, err := c.tagInfoCache.GetTagInfo(tStruct)
if err != nil {
return err
}
@ -244,7 +252,7 @@ func (c DB) QueryOne(
}
if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
selectPrefix, err := c.buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
if err != nil {
return err
}
@ -264,7 +272,7 @@ func (c DB) QueryOne(
return ErrRecordNotFound
}
err = scanRowsFromType(c.dialect, rows, record, t, v)
err = c.scanRowsFromType(c.dialect, rows, record, t, v)
if err != nil {
return err
}
@ -305,7 +313,7 @@ func (c DB) QueryChunks(
return err
}
info, err := structs.GetTagInfo(structType)
info, err := c.tagInfoCache.GetTagInfo(structType)
if err != nil {
return err
}
@ -317,7 +325,7 @@ func (c DB) QueryChunks(
}
if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
if err != nil {
return err
}
@ -343,7 +351,7 @@ func (c DB) QueryChunks(
chunk = reflect.Append(chunk, elemValue)
}
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
err = c.scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
if err != nil {
return err
}
@ -415,7 +423,7 @@ func (c DB) Insert(
return fmt.Errorf("can't insert in ksql.Table: %s", err)
}
info, err := structs.GetTagInfo(t.Elem())
info, err := c.tagInfoCache.GetTagInfo(t.Elem())
if err != nil {
return err
}
@ -652,7 +660,7 @@ func (c DB) Patch(
}
tStruct = t.Elem()
}
info, err := structs.GetTagInfo(tStruct)
info, err := c.tagInfoCache.GetTagInfo(tStruct)
if err != nil {
return err
}
@ -929,13 +937,13 @@ func (nopScanner) Scan(value interface{}) error {
return nil
}
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
func (c DB) scanRows(dialect Dialect, rows Rows, record interface{}) error {
v := reflect.ValueOf(record)
t := v.Type()
return scanRowsFromType(dialect, rows, record, t, v)
return c.scanRowsFromType(dialect, rows, record, t, v)
}
func scanRowsFromType(
func (c DB) scanRowsFromType(
dialect Dialect,
rows Rows,
record interface{},
@ -953,7 +961,7 @@ func scanRowsFromType(
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
}
info, err := structs.GetTagInfo(t)
info, err := c.tagInfoCache.GetTagInfo(t)
if err != nil {
return err
}
@ -963,7 +971,7 @@ func scanRowsFromType(
// This version is positional meaning that it expect the arguments
// to follow an specific order. It's ok because we don't allow the
// user to type the "SELECT" part of the query for nested structs.
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
scanArgs, err = c.getScanArgsForNestedStructs(dialect, rows, t, v, info)
if err != nil {
return err
}
@ -980,7 +988,7 @@ func scanRowsFromType(
return rows.Scan(scanArgs...)
}
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
func (c DB) getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
scanArgs := []interface{}{}
for i := 0; i < v.NumField(); i++ {
if !info.ByIndex(i).Valid {
@ -988,7 +996,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
}
// TODO(vingarcia00): Handle case where type is pointer
nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type)
nestedStructInfo, err := c.tagInfoCache.GetTagInfo(t.Field(i).Type)
if err != nil {
return nil, err
}
@ -1076,7 +1084,7 @@ func getFirstToken(s string) string {
return token.String()
}
func buildSelectQuery(
func (c DB) buildSelectQuery(
dialect Dialect,
structType reflect.Type,
info structs.StructInfo,
@ -1091,7 +1099,7 @@ func buildSelectQuery(
}
if info.IsNestedStruct {
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
query, err = c.buildSelectQueryForNestedStructs(dialect, structType, info)
if err != nil {
return "", err
}
@ -1121,7 +1129,7 @@ func buildSelectQueryForPlainStructs(
return "SELECT " + strings.Join(fields, ", ") + " "
}
func buildSelectQueryForNestedStructs(
func (c DB) buildSelectQueryForNestedStructs(
dialect Dialect,
structType reflect.Type,
info structs.StructInfo,
@ -1142,7 +1150,7 @@ func buildSelectQueryForNestedStructs(
)
}
nestedStructTagInfo, err := structs.GetTagInfo(nestedStructType)
nestedStructTagInfo, err := c.tagInfoCache.GetTagInfo(nestedStructType)
if err != nil {
return "", err
}

View File

@ -9,6 +9,7 @@ import (
"testing"
"github.com/pkg/errors"
"github.com/vingarcia/ksql/internal/structs"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/nullable"
)
@ -2455,7 +2456,7 @@ func ScanRowsTest(
tt.AssertEqual(t, rows.Next(), true)
var u user
err = scanRows(dialect, rows, &u)
err = c.scanRows(dialect, rows, &u)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Name, "User2")
@ -2488,7 +2489,7 @@ func ScanRowsTest(
// Omitted for testing purposes:
// Name string `ksql:"name"`
}
err = scanRows(dialect, rows, &u)
err = c.scanRows(dialect, rows, &u)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Age, 22)
@ -2504,6 +2505,7 @@ func ScanRowsTest(
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err)
@ -2511,7 +2513,7 @@ func ScanRowsTest(
var u user
err = rows.Close()
tt.AssertNoErr(t, err)
err = scanRows(dialect, rows, &u)
err = c.scanRows(dialect, rows, &u)
tt.AssertNotEqual(t, err, nil)
})
@ -2525,13 +2527,14 @@ func ScanRowsTest(
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err)
defer rows.Close()
var u user
err = scanRows(dialect, rows, u)
err = c.scanRows(dialect, rows, u)
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
})
@ -2545,13 +2548,14 @@ func ScanRowsTest(
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err)
defer rows.Close()
var u map[string]interface{}
err = scanRows(dialect, rows, &u)
err = c.scanRows(dialect, rows, &u)
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "map[string]interface")
})
})
@ -2684,6 +2688,8 @@ func newTestDB(db DBAdapter, driver string) DB {
driver: driver,
dialect: supportedDialects[driver],
db: db,
tagInfoCache: structs.TagInfoCache{},
}
}