mirror of https://github.com/VinGarcia/ksql.git
Revert "Decouple ksql.DB from TagInfoCache so we can replace it during tests"
This reverts commit 74cb87bea0
.
This was done because I noticed this first commit was unnecessary.
This original commit was written in order to allow tests where
the cache would return errors, but I noticed there is a way
of provoking these errors without the need of an extra layer
of abstraction.
Thus, in order to keep the code simpler and also avoid an extra
level of indirection I am undoing this change.
pull/29/head
parent
4b37adc905
commit
8620600d01
|
@ -69,24 +69,12 @@ func (s StructInfo) NumFields() int {
|
|||
// works fine.
|
||||
var tagInfoCache = &sync.Map{}
|
||||
|
||||
// TagInfoCache implements the ksql.TagInfoCache interface
|
||||
// this abstraction was created for allowing the use of
|
||||
// mocks during tests.
|
||||
type TagInfoCache struct{}
|
||||
|
||||
// GetTagInfo efficiently returns the type information
|
||||
// using a global private cache
|
||||
//
|
||||
// In the future we might move this cache inside
|
||||
// a struct, but for now this accessor is the one
|
||||
// we are using
|
||||
func (t TagInfoCache) GetTagInfo(key reflect.Type) (StructInfo, error) {
|
||||
return getCachedTagInfo(tagInfoCache, key)
|
||||
}
|
||||
|
||||
// GetTagInfo is the static version of the method above
|
||||
// created for convenience when the extra abstraction
|
||||
// is unnecessary.
|
||||
func GetTagInfo(key reflect.Type) (StructInfo, error) {
|
||||
return getCachedTagInfo(tagInfoCache, key)
|
||||
}
|
||||
|
|
52
ksql.go
52
ksql.go
|
@ -33,8 +33,6 @@ type DB struct {
|
|||
driver string
|
||||
dialect Dialect
|
||||
db DBAdapter
|
||||
|
||||
tagInfoCache TagInfoCache
|
||||
}
|
||||
|
||||
// DBAdapter is minimalistic interface to decouple our implementation
|
||||
|
@ -47,10 +45,6 @@ type DBAdapter interface {
|
|||
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
|
||||
}
|
||||
|
||||
type TagInfoCache interface {
|
||||
GetTagInfo(key reflect.Type) (structs.StructInfo, error)
|
||||
}
|
||||
|
||||
// TxBeginner needs to be implemented by the DBAdapter in order to make it possible
|
||||
// to use the `ksql.Transaction()` function.
|
||||
type TxBeginner interface {
|
||||
|
@ -113,8 +107,6 @@ func NewWithAdapter(
|
|||
dialect: dialect,
|
||||
driver: dialectName,
|
||||
db: db,
|
||||
|
||||
tagInfoCache: structs.TagInfoCache{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -150,7 +142,7 @@ func (c DB) Query(
|
|||
slice = slice.Slice(0, 0)
|
||||
}
|
||||
|
||||
info, err := c.tagInfoCache.GetTagInfo(structType)
|
||||
info, err := structs.GetTagInfo(structType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -162,7 +154,7 @@ func (c DB) Query(
|
|||
}
|
||||
|
||||
if firstToken == "FROM" {
|
||||
selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -193,7 +185,7 @@ func (c DB) Query(
|
|||
elemPtr = elemPtr.Elem()
|
||||
}
|
||||
|
||||
err = c.scanRows(c.dialect, rows, elemPtr.Interface())
|
||||
err = scanRows(c.dialect, rows, elemPtr.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -240,7 +232,7 @@ func (c DB) QueryOne(
|
|||
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
||||
info, err := c.tagInfoCache.GetTagInfo(tStruct)
|
||||
info, err := structs.GetTagInfo(tStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -252,7 +244,7 @@ func (c DB) QueryOne(
|
|||
}
|
||||
|
||||
if firstToken == "FROM" {
|
||||
selectPrefix, err := c.buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -272,7 +264,7 @@ func (c DB) QueryOne(
|
|||
return ErrRecordNotFound
|
||||
}
|
||||
|
||||
err = c.scanRowsFromType(c.dialect, rows, record, t, v)
|
||||
err = scanRowsFromType(c.dialect, rows, record, t, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -313,7 +305,7 @@ func (c DB) QueryChunks(
|
|||
return err
|
||||
}
|
||||
|
||||
info, err := c.tagInfoCache.GetTagInfo(structType)
|
||||
info, err := structs.GetTagInfo(structType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -325,7 +317,7 @@ func (c DB) QueryChunks(
|
|||
}
|
||||
|
||||
if firstToken == "FROM" {
|
||||
selectPrefix, err := c.buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -351,7 +343,7 @@ func (c DB) QueryChunks(
|
|||
chunk = reflect.Append(chunk, elemValue)
|
||||
}
|
||||
|
||||
err = c.scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
||||
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -423,7 +415,7 @@ func (c DB) Insert(
|
|||
return fmt.Errorf("can't insert in ksql.Table: %s", err)
|
||||
}
|
||||
|
||||
info, err := c.tagInfoCache.GetTagInfo(t.Elem())
|
||||
info, err := structs.GetTagInfo(t.Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -660,7 +652,7 @@ func (c DB) Patch(
|
|||
}
|
||||
tStruct = t.Elem()
|
||||
}
|
||||
info, err := c.tagInfoCache.GetTagInfo(tStruct)
|
||||
info, err := structs.GetTagInfo(tStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -937,13 +929,13 @@ func (nopScanner) Scan(value interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c DB) scanRows(dialect Dialect, rows Rows, record interface{}) error {
|
||||
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
return c.scanRowsFromType(dialect, rows, record, t, v)
|
||||
return scanRowsFromType(dialect, rows, record, t, v)
|
||||
}
|
||||
|
||||
func (c DB) scanRowsFromType(
|
||||
func scanRowsFromType(
|
||||
dialect Dialect,
|
||||
rows Rows,
|
||||
record interface{},
|
||||
|
@ -961,7 +953,7 @@ func (c DB) scanRowsFromType(
|
|||
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
||||
info, err := c.tagInfoCache.GetTagInfo(t)
|
||||
info, err := structs.GetTagInfo(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -971,7 +963,7 @@ func (c DB) scanRowsFromType(
|
|||
// This version is positional meaning that it expect the arguments
|
||||
// to follow an specific order. It's ok because we don't allow the
|
||||
// user to type the "SELECT" part of the query for nested structs.
|
||||
scanArgs, err = c.getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -988,7 +980,7 @@ func (c DB) scanRowsFromType(
|
|||
return rows.Scan(scanArgs...)
|
||||
}
|
||||
|
||||
func (c DB) getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
|
||||
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
|
||||
scanArgs := []interface{}{}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !info.ByIndex(i).Valid {
|
||||
|
@ -996,7 +988,7 @@ func (c DB) getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Ty
|
|||
}
|
||||
|
||||
// TODO(vingarcia00): Handle case where type is pointer
|
||||
nestedStructInfo, err := c.tagInfoCache.GetTagInfo(t.Field(i).Type)
|
||||
nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1084,7 +1076,7 @@ func getFirstToken(s string) string {
|
|||
return token.String()
|
||||
}
|
||||
|
||||
func (c DB) buildSelectQuery(
|
||||
func buildSelectQuery(
|
||||
dialect Dialect,
|
||||
structType reflect.Type,
|
||||
info structs.StructInfo,
|
||||
|
@ -1099,7 +1091,7 @@ func (c DB) buildSelectQuery(
|
|||
}
|
||||
|
||||
if info.IsNestedStruct {
|
||||
query, err = c.buildSelectQueryForNestedStructs(dialect, structType, info)
|
||||
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -1129,7 +1121,7 @@ func buildSelectQueryForPlainStructs(
|
|||
return "SELECT " + strings.Join(fields, ", ") + " "
|
||||
}
|
||||
|
||||
func (c DB) buildSelectQueryForNestedStructs(
|
||||
func buildSelectQueryForNestedStructs(
|
||||
dialect Dialect,
|
||||
structType reflect.Type,
|
||||
info structs.StructInfo,
|
||||
|
@ -1150,7 +1142,7 @@ func (c DB) buildSelectQueryForNestedStructs(
|
|||
)
|
||||
}
|
||||
|
||||
nestedStructTagInfo, err := c.tagInfoCache.GetTagInfo(nestedStructType)
|
||||
nestedStructTagInfo, err := structs.GetTagInfo(nestedStructType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vingarcia/ksql/internal/structs"
|
||||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
"github.com/vingarcia/ksql/nullable"
|
||||
)
|
||||
|
@ -2548,7 +2547,7 @@ func ScanRowsTest(
|
|||
tt.AssertEqual(t, rows.Next(), true)
|
||||
|
||||
var u user
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
err = scanRows(dialect, rows, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, u.Name, "User2")
|
||||
|
@ -2581,7 +2580,7 @@ func ScanRowsTest(
|
|||
// Omitted for testing purposes:
|
||||
// Name string `ksql:"name"`
|
||||
}
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
err = scanRows(dialect, rows, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, u.Age, 22)
|
||||
|
@ -2597,7 +2596,6 @@ func ScanRowsTest(
|
|||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
||||
tt.AssertNoErr(t, err)
|
||||
|
@ -2605,7 +2603,7 @@ func ScanRowsTest(
|
|||
var u user
|
||||
err = rows.Close()
|
||||
tt.AssertNoErr(t, err)
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
err = scanRows(dialect, rows, &u)
|
||||
tt.AssertNotEqual(t, err, nil)
|
||||
})
|
||||
|
||||
|
@ -2619,14 +2617,13 @@ func ScanRowsTest(
|
|||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
||||
tt.AssertNoErr(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var u user
|
||||
err = c.scanRows(dialect, rows, u)
|
||||
err = scanRows(dialect, rows, u)
|
||||
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "user")
|
||||
})
|
||||
|
||||
|
@ -2640,14 +2637,13 @@ func ScanRowsTest(
|
|||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User2'")
|
||||
tt.AssertNoErr(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var u map[string]interface{}
|
||||
err = c.scanRows(dialect, rows, &u)
|
||||
err = scanRows(dialect, rows, &u)
|
||||
tt.AssertErrContains(t, err, "ksql", "expected", "pointer to struct", "map[string]interface")
|
||||
})
|
||||
})
|
||||
|
@ -2780,8 +2776,6 @@ func newTestDB(db DBAdapter, driver string) DB {
|
|||
driver: driver,
|
||||
dialect: supportedDialects[driver],
|
||||
db: db,
|
||||
|
||||
tagInfoCache: structs.TagInfoCache{},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue