Decouple ksql.DB from TagInfoCache so we can replace it during tests

pull/29/head
Vinícius Garcia 2022-08-24 22:45:24 -03:00
parent 0e95506343
commit 74cb87bea0
3 changed files with 53 additions and 27 deletions

View File

@ -69,12 +69,24 @@ func (s StructInfo) NumFields() int {
// works fine. // works fine.
var tagInfoCache = &sync.Map{} var tagInfoCache = &sync.Map{}
// TagInfoCache implements the ksql.TagInfoCache interface
// this abstraction was created for allowing the use of
// mocks during tests.
type TagInfoCache struct{}
// GetTagInfo efficiently returns the type information // GetTagInfo efficiently returns the type information
// using a global private cache // using a global private cache
// //
// In the future we might move this cache inside // In the future we might move this cache inside
// a struct, but for now this accessor is the one // a struct, but for now this accessor is the one
// we are using // we are using
func (t TagInfoCache) GetTagInfo(key reflect.Type) (StructInfo, error) {
return getCachedTagInfo(tagInfoCache, key)
}
// GetTagInfo is the static version of the method above
// created for convenience when the extra abstraction
// is unnecessary.
func GetTagInfo(key reflect.Type) (StructInfo, error) { func GetTagInfo(key reflect.Type) (StructInfo, error) {
return getCachedTagInfo(tagInfoCache, key) return getCachedTagInfo(tagInfoCache, key)
} }

52
ksql.go
View File

@ -33,6 +33,8 @@ type DB struct {
driver string driver string
dialect Dialect dialect Dialect
db DBAdapter db DBAdapter
tagInfoCache TagInfoCache
} }
// DBAdapter is minimalistic interface to decouple our implementation // DBAdapter is minimalistic interface to decouple our implementation
@ -45,6 +47,10 @@ type DBAdapter interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
} }
type TagInfoCache interface {
GetTagInfo(key reflect.Type) (structs.StructInfo, error)
}
// TxBeginner needs to be implemented by the DBAdapter in order to make it possible // TxBeginner needs to be implemented by the DBAdapter in order to make it possible
// to use the `ksql.Transaction()` function. // to use the `ksql.Transaction()` function.
type TxBeginner interface { type TxBeginner interface {
@ -107,6 +113,8 @@ func NewWithAdapter(
dialect: dialect, dialect: dialect,
driver: dialectName, driver: dialectName,
db: db, db: db,
tagInfoCache: structs.TagInfoCache{},
}, nil }, nil
} }
@ -142,7 +150,7 @@ func (c DB) Query(
slice = slice.Slice(0, 0) slice = slice.Slice(0, 0)
} }
info, err := structs.GetTagInfo(structType) info, err := c.tagInfoCache.GetTagInfo(structType)
if err != nil { if err != nil {
return err return err
} }
@ -154,7 +162,7 @@ func (c DB) Query(
} }
if firstToken == "FROM" { if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
if err != nil { if err != nil {
return err return err
} }
@ -185,7 +193,7 @@ func (c DB) Query(
elemPtr = elemPtr.Elem() elemPtr = elemPtr.Elem()
} }
err = scanRows(c.dialect, rows, elemPtr.Interface()) err = c.scanRows(c.dialect, rows, elemPtr.Interface())
if err != nil { if err != nil {
return err return err
} }
@ -232,7 +240,7 @@ func (c DB) QueryOne(
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
} }
info, err := structs.GetTagInfo(tStruct) info, err := c.tagInfoCache.GetTagInfo(tStruct)
if err != nil { if err != nil {
return err return err
} }
@ -244,7 +252,7 @@ func (c DB) QueryOne(
} }
if firstToken == "FROM" { if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()]) selectPrefix, err := c.buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
if err != nil { if err != nil {
return err return err
} }
@ -264,7 +272,7 @@ func (c DB) QueryOne(
return ErrRecordNotFound return ErrRecordNotFound
} }
err = scanRowsFromType(c.dialect, rows, record, t, v) err = c.scanRowsFromType(c.dialect, rows, record, t, v)
if err != nil { if err != nil {
return err return err
} }
@ -305,7 +313,7 @@ func (c DB) QueryChunks(
return err return err
} }
info, err := structs.GetTagInfo(structType) info, err := c.tagInfoCache.GetTagInfo(structType)
if err != nil { if err != nil {
return err return err
} }
@ -317,7 +325,7 @@ func (c DB) QueryChunks(
} }
if firstToken == "FROM" { if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
if err != nil { if err != nil {
return err return err
} }
@ -343,7 +351,7 @@ func (c DB) QueryChunks(
chunk = reflect.Append(chunk, elemValue) chunk = reflect.Append(chunk, elemValue)
} }
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) err = c.scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
if err != nil { if err != nil {
return err return err
} }
@ -415,7 +423,7 @@ func (c DB) Insert(
return fmt.Errorf("can't insert in ksql.Table: %s", err) return fmt.Errorf("can't insert in ksql.Table: %s", err)
} }
info, err := structs.GetTagInfo(t.Elem()) info, err := c.tagInfoCache.GetTagInfo(t.Elem())
if err != nil { if err != nil {
return err return err
} }
@ -652,7 +660,7 @@ func (c DB) Patch(
} }
tStruct = t.Elem() tStruct = t.Elem()
} }
info, err := structs.GetTagInfo(tStruct) info, err := c.tagInfoCache.GetTagInfo(tStruct)
if err != nil { if err != nil {
return err return err
} }
@ -929,13 +937,13 @@ func (nopScanner) Scan(value interface{}) error {
return nil return nil
} }
func scanRows(dialect Dialect, rows Rows, record interface{}) error { func (c DB) scanRows(dialect Dialect, rows Rows, record interface{}) error {
v := reflect.ValueOf(record) v := reflect.ValueOf(record)
t := v.Type() t := v.Type()
return scanRowsFromType(dialect, rows, record, t, v) return c.scanRowsFromType(dialect, rows, record, t, v)
} }
func scanRowsFromType( func (c DB) scanRowsFromType(
dialect Dialect, dialect Dialect,
rows Rows, rows Rows,
record interface{}, record interface{},
@ -953,7 +961,7 @@ func scanRowsFromType(
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record) return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
} }
info, err := structs.GetTagInfo(t) info, err := c.tagInfoCache.GetTagInfo(t)
if err != nil { if err != nil {
return err return err
} }
@ -963,7 +971,7 @@ func scanRowsFromType(
// This version is positional meaning that it expect the arguments // This version is positional meaning that it expect the arguments
// to follow an specific order. It's ok because we don't allow the // to follow an specific order. It's ok because we don't allow the
// user to type the "SELECT" part of the query for nested structs. // user to type the "SELECT" part of the query for nested structs.
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info) scanArgs, err = c.getScanArgsForNestedStructs(dialect, rows, t, v, info)
if err != nil { if err != nil {
return err return err
} }
@ -980,7 +988,7 @@ func scanRowsFromType(
return rows.Scan(scanArgs...) return rows.Scan(scanArgs...)
} }
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) { func (c DB) getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
scanArgs := []interface{}{} scanArgs := []interface{}{}
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
if !info.ByIndex(i).Valid { if !info.ByIndex(i).Valid {
@ -988,7 +996,7 @@ func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v r
} }
// TODO(vingarcia00): Handle case where type is pointer // TODO(vingarcia00): Handle case where type is pointer
nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type) nestedStructInfo, err := c.tagInfoCache.GetTagInfo(t.Field(i).Type)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1076,7 +1084,7 @@ func getFirstToken(s string) string {
return token.String() return token.String()
} }
func buildSelectQuery( func (c DB) buildSelectQuery(
dialect Dialect, dialect Dialect,
structType reflect.Type, structType reflect.Type,
info structs.StructInfo, info structs.StructInfo,
@ -1091,7 +1099,7 @@ func buildSelectQuery(
} }
if info.IsNestedStruct { if info.IsNestedStruct {
query, err = buildSelectQueryForNestedStructs(dialect, structType, info) query, err = c.buildSelectQueryForNestedStructs(dialect, structType, info)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -1121,7 +1129,7 @@ func buildSelectQueryForPlainStructs(
return "SELECT " + strings.Join(fields, ", ") + " " return "SELECT " + strings.Join(fields, ", ") + " "
} }
func buildSelectQueryForNestedStructs( func (c DB) buildSelectQueryForNestedStructs(
dialect Dialect, dialect Dialect,
structType reflect.Type, structType reflect.Type,
info structs.StructInfo, info structs.StructInfo,
@ -1142,7 +1150,7 @@ func buildSelectQueryForNestedStructs(
) )
} }
nestedStructTagInfo, err := structs.GetTagInfo(nestedStructType) nestedStructTagInfo, err := c.tagInfoCache.GetTagInfo(nestedStructType)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/vingarcia/ksql/internal/structs"
tt "github.com/vingarcia/ksql/internal/testtools" tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/nullable" "github.com/vingarcia/ksql/nullable"
) )
@ -2455,7 +2456,7 @@ func ScanRowsTest(
tt.AssertEqual(t, rows.Next(), true) tt.AssertEqual(t, rows.Next(), true)
var u user var u user
err = scanRows(dialect, rows, &u) err = c.scanRows(dialect, rows, &u)
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Name, "User2") tt.AssertEqual(t, u.Name, "User2")
@ -2488,7 +2489,7 @@ func ScanRowsTest(
// Omitted for testing purposes: // Omitted for testing purposes:
// Name string `ksql:"name"` // Name string `ksql:"name"`
} }
err = scanRows(dialect, rows, &u) err = c.scanRows(dialect, rows, &u)
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
tt.AssertEqual(t, u.Age, 22) tt.AssertEqual(t, u.Age, 22)
@ -2504,6 +2505,7 @@ func ScanRowsTest(
ctx := context.TODO() ctx := context.TODO()
db, closer := newDBAdapter(t) db, closer := newDBAdapter(t)
defer closer.Close() defer closer.Close()
c := newTestDB(db, driver)
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'") rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
@ -2511,7 +2513,7 @@ func ScanRowsTest(
var u user var u user
err = rows.Close() err = rows.Close()
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
err = scanRows(dialect, rows, &u) err = c.scanRows(dialect, rows, &u)
tt.AssertNotEqual(t, err, nil) tt.AssertNotEqual(t, err, nil)
}) })
@ -2525,13 +2527,14 @@ func ScanRowsTest(
ctx := context.TODO() ctx := context.TODO()
db, closer := newDBAdapter(t) db, closer := newDBAdapter(t)
defer closer.Close() defer closer.Close()
c := newTestDB(db, driver)
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'") rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
defer rows.Close() defer rows.Close()
var u user var u user
err = scanRows(dialect, rows, u) err = c.scanRows(dialect, rows, u)
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user") tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
}) })
@ -2545,13 +2548,14 @@ func ScanRowsTest(
ctx := context.TODO() ctx := context.TODO()
db, closer := newDBAdapter(t) db, closer := newDBAdapter(t)
defer closer.Close() defer closer.Close()
c := newTestDB(db, driver)
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'") rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
defer rows.Close() defer rows.Close()
var u map[string]interface{} var u map[string]interface{}
err = scanRows(dialect, rows, &u) err = c.scanRows(dialect, rows, &u)
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "map[string]interface") tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "map[string]interface")
}) })
}) })
@ -2684,6 +2688,8 @@ func newTestDB(db DBAdapter, driver string) DB {
driver: driver, driver: driver,
dialect: supportedDialects[driver], dialect: supportedDialects[driver],
db: db, db: db,
tagInfoCache: structs.TagInfoCache{},
} }
} }